OLD | NEW |
(Empty) | |
| 1 from Cython.Compiler import TypeSlots |
| 2 from Cython.Compiler.ExprNodes import not_a_constant |
| 3 import cython |
| 4 cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object, |
| 5 Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object
, |
| 6 UtilNodes=object, Naming=object) |
| 7 |
| 8 import Nodes |
| 9 import ExprNodes |
| 10 import PyrexTypes |
| 11 import Visitor |
| 12 import Builtin |
| 13 import UtilNodes |
| 14 import Options |
| 15 import Naming |
| 16 |
| 17 from Code import UtilityCode |
| 18 from StringEncoding import EncodedString, BytesLiteral |
| 19 from Errors import error |
| 20 from ParseTreeTransforms import SkipDeclarations |
| 21 |
| 22 import copy |
| 23 import codecs |
| 24 |
| 25 try: |
| 26 from __builtin__ import reduce |
| 27 except ImportError: |
| 28 from functools import reduce |
| 29 |
| 30 try: |
| 31 from __builtin__ import basestring |
| 32 except ImportError: |
| 33 basestring = str # Python 3 |
| 34 |
def load_c_utility(name):
    """Return the cached utility code section *name* loaded from "Optimize.c"."""
    utility_code = UtilityCode.load_cached(name, "Optimize.c")
    return utility_code
| 37 |
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Return ``node.arg`` if *node* is one of *coercion_nodes*, else *node* itself.

    Only a single level of coercion is removed.
    """
    return node.arg if isinstance(node, coercion_nodes) else node
| 42 |
def unwrap_node(node):
    """Follow ``.expression`` through any chain of ResultRefNode wrappers."""
    while True:
        if not isinstance(node, UtilNodes.ResultRefNode):
            return node
        node = node.expression
| 47 |
def is_common_value(a, b):
    """Tell whether *a* and *b* refer to the same name or attribute chain.

    Both nodes are first stripped of ResultRefNode wrappers.  Two NameNodes
    match when their names are equal.  Two AttributeNodes match when the
    first is not a Python attribute, their objects match recursively and
    their attribute names are equal.  Anything else does not match.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    if isinstance(left, ExprNodes.NameNode) and isinstance(right, ExprNodes.NameNode):
        return left.name == right.name

    both_attributes = (isinstance(left, ExprNodes.AttributeNode)
                       and isinstance(right, ExprNodes.AttributeNode))
    if not both_attributes:
        return False
    if left.is_py_attr:
        return False
    return is_common_value(left.obj, right.obj) and left.attribute == right.attribute
| 56 |
def filter_none_node(node):
    """Map a node whose ``constant_result`` is None to a plain ``None``.

    ``None`` input and nodes with any other constant result are returned
    unchanged.
    """
    if node is None or node.constant_result is not None:
        return node
    return None
| 61 |
| 62 class IterationTransform(Visitor.EnvTransform): |
| 63 """Transform some common for-in loop patterns into efficient C loops: |
| 64 |
| 65 - for-in-dict loop becomes a while loop calling PyDict_Next() |
| 66 - for-in-enumerate is replaced by an external counter variable |
| 67 - for-in-range loop becomes a plain C for loop |
| 68 """ |
| 69 def visit_PrimaryCmpNode(self, node): |
| 70 if node.is_ptr_contains(): |
| 71 |
| 72 # for t in operand2: |
| 73 # if operand1 == t: |
| 74 # res = True |
| 75 # break |
| 76 # else: |
| 77 # res = False |
| 78 |
| 79 pos = node.pos |
| 80 result_ref = UtilNodes.ResultRefNode(node) |
| 81 if isinstance(node.operand2, ExprNodes.IndexNode): |
| 82 base_type = node.operand2.base.type.base_type |
| 83 else: |
| 84 base_type = node.operand2.type.base_type |
| 85 target_handle = UtilNodes.TempHandle(base_type) |
| 86 target = target_handle.ref(pos) |
| 87 cmp_node = ExprNodes.PrimaryCmpNode( |
| 88 pos, operator=u'==', operand1=node.operand1, operand2=target) |
| 89 if_body = Nodes.StatListNode( |
| 90 pos, |
| 91 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=Exp
rNodes.BoolNode(pos, value=1)), |
| 92 Nodes.BreakStatNode(pos)]) |
| 93 if_node = Nodes.IfStatNode( |
| 94 pos, |
| 95 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_
body)], |
| 96 else_clause=None) |
| 97 for_loop = UtilNodes.TempsBlockNode( |
| 98 pos, |
| 99 temps = [target_handle], |
| 100 body = Nodes.ForInStatNode( |
| 101 pos, |
| 102 target=target, |
| 103 iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=
node.operand2), |
| 104 body=if_node, |
| 105 else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref,
rhs=ExprNodes.BoolNode(pos, value=0)))) |
| 106 for_loop = for_loop.analyse_expressions(self.current_env()) |
| 107 for_loop = self.visit(for_loop) |
| 108 new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) |
| 109 |
| 110 if node.operator == 'not_in': |
| 111 new_node = ExprNodes.NotNode(pos, operand=new_node) |
| 112 return new_node |
| 113 |
| 114 else: |
| 115 self.visitchildren(node) |
| 116 return node |
| 117 |
| 118 def visit_ForInStatNode(self, node): |
| 119 self.visitchildren(node) |
| 120 return self._optimise_for_loop(node, node.iterator.sequence) |
| 121 |
| 122 def _optimise_for_loop(self, node, iterator, reversed=False): |
| 123 if iterator.type is Builtin.dict_type: |
| 124 # like iterating over dict.keys() |
| 125 if reversed: |
| 126 # CPython raises an error here: not a sequence |
| 127 return node |
| 128 return self._transform_dict_iteration( |
| 129 node, dict_obj=iterator, method=None, keys=True, values=False) |
| 130 |
| 131 # C array (slice) iteration? |
| 132 if iterator.type.is_ptr or iterator.type.is_array: |
| 133 return self._transform_carray_iteration(node, iterator, reversed=rev
ersed) |
| 134 if iterator.type is Builtin.bytes_type: |
| 135 return self._transform_bytes_iteration(node, iterator, reversed=reve
rsed) |
| 136 if iterator.type is Builtin.unicode_type: |
| 137 return self._transform_unicode_iteration(node, iterator, reversed=re
versed) |
| 138 |
| 139 # the rest is based on function calls |
| 140 if not isinstance(iterator, ExprNodes.SimpleCallNode): |
| 141 return node |
| 142 |
| 143 if iterator.args is None: |
| 144 arg_count = iterator.arg_tuple and len(iterator.arg_tuple.args) or 0 |
| 145 else: |
| 146 arg_count = len(iterator.args) |
| 147 if arg_count and iterator.self is not None: |
| 148 arg_count -= 1 |
| 149 |
| 150 function = iterator.function |
| 151 # dict iteration? |
| 152 if function.is_attribute and not reversed and not arg_count: |
| 153 base_obj = iterator.self or function.obj |
| 154 method = function.attribute |
| 155 # in Py3, items() is equivalent to Py2's iteritems() |
| 156 is_safe_iter = self.global_scope().context.language_level >= 3 |
| 157 |
| 158 if not is_safe_iter and method in ('keys', 'values', 'items'): |
| 159 # try to reduce this to the corresponding .iter*() methods |
| 160 if isinstance(base_obj, ExprNodes.SimpleCallNode): |
| 161 inner_function = base_obj.function |
| 162 if (inner_function.is_name and inner_function.name == 'dict' |
| 163 and inner_function.entry |
| 164 and inner_function.entry.is_builtin): |
| 165 # e.g. dict(something).items() => safe to use .iter*() |
| 166 is_safe_iter = True |
| 167 |
| 168 keys = values = False |
| 169 if method == 'iterkeys' or (is_safe_iter and method == 'keys'): |
| 170 keys = True |
| 171 elif method == 'itervalues' or (is_safe_iter and method == 'values')
: |
| 172 values = True |
| 173 elif method == 'iteritems' or (is_safe_iter and method == 'items'): |
| 174 keys = values = True |
| 175 |
| 176 if keys or values: |
| 177 return self._transform_dict_iteration( |
| 178 node, base_obj, method, keys, values) |
| 179 |
| 180 # enumerate/reversed ? |
| 181 if iterator.self is None and function.is_name and \ |
| 182 function.entry and function.entry.is_builtin: |
| 183 if function.name == 'enumerate': |
| 184 if reversed: |
| 185 # CPython raises an error here: not a sequence |
| 186 return node |
| 187 return self._transform_enumerate_iteration(node, iterator) |
| 188 elif function.name == 'reversed': |
| 189 if reversed: |
| 190 # CPython raises an error here: not a sequence |
| 191 return node |
| 192 return self._transform_reversed_iteration(node, iterator) |
| 193 |
| 194 # range() iteration? |
| 195 if Options.convert_range and node.target.type.is_int: |
| 196 if iterator.self is None and function.is_name and \ |
| 197 function.entry and function.entry.is_builtin and \ |
| 198 function.name in ('range', 'xrange'): |
| 199 return self._transform_range_iteration(node, iterator, reversed=
reversed) |
| 200 |
| 201 return node |
| 202 |
| 203 def _transform_reversed_iteration(self, node, reversed_function): |
| 204 args = reversed_function.arg_tuple.args |
| 205 if len(args) == 0: |
| 206 error(reversed_function.pos, |
| 207 "reversed() requires an iterable argument") |
| 208 return node |
| 209 elif len(args) > 1: |
| 210 error(reversed_function.pos, |
| 211 "reversed() takes exactly 1 argument") |
| 212 return node |
| 213 arg = args[0] |
| 214 |
| 215 # reversed(list/tuple) ? |
| 216 if arg.type in (Builtin.tuple_type, Builtin.list_type): |
| 217 node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is
not iterable") |
| 218 node.iterator.reversed = True |
| 219 return node |
| 220 |
| 221 return self._optimise_for_loop(node, arg, reversed=True) |
| 222 |
| 223 PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType( |
| 224 PyrexTypes.c_char_ptr_type, [ |
| 225 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) |
| 226 ]) |
| 227 |
| 228 PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType( |
| 229 PyrexTypes.c_py_ssize_t_type, [ |
| 230 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) |
| 231 ]) |
| 232 |
| 233 def _transform_bytes_iteration(self, node, slice_node, reversed=False): |
| 234 target_type = node.target.type |
| 235 if not target_type.is_int and target_type is not Builtin.bytes_type: |
| 236 # bytes iteration returns bytes objects in Py2, but |
| 237 # integers in Py3 |
| 238 return node |
| 239 |
| 240 unpack_temp_node = UtilNodes.LetRefNode( |
| 241 slice_node.as_none_safe_node("'NoneType' is not iterable")) |
| 242 |
| 243 slice_base_node = ExprNodes.PythonCapiCallNode( |
| 244 slice_node.pos, "PyBytes_AS_STRING", |
| 245 self.PyBytes_AS_STRING_func_type, |
| 246 args = [unpack_temp_node], |
| 247 is_temp = 0, |
| 248 ) |
| 249 len_node = ExprNodes.PythonCapiCallNode( |
| 250 slice_node.pos, "PyBytes_GET_SIZE", |
| 251 self.PyBytes_GET_SIZE_func_type, |
| 252 args = [unpack_temp_node], |
| 253 is_temp = 0, |
| 254 ) |
| 255 |
| 256 return UtilNodes.LetNode( |
| 257 unpack_temp_node, |
| 258 self._transform_carray_iteration( |
| 259 node, |
| 260 ExprNodes.SliceIndexNode( |
| 261 slice_node.pos, |
| 262 base = slice_base_node, |
| 263 start = None, |
| 264 step = None, |
| 265 stop = len_node, |
| 266 type = slice_base_node.type, |
| 267 is_temp = 1, |
| 268 ), |
| 269 reversed = reversed)) |
| 270 |
| 271 PyUnicode_READ_func_type = PyrexTypes.CFuncType( |
| 272 PyrexTypes.c_py_ucs4_type, [ |
| 273 PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None), |
| 274 PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None), |
| 275 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None) |
| 276 ]) |
| 277 |
| 278 init_unicode_iteration_func_type = PyrexTypes.CFuncType( |
| 279 PyrexTypes.c_int_type, [ |
| 280 PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None), |
| 281 PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type,
None), |
| 282 PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None
), |
| 283 PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None) |
| 284 ], |
| 285 exception_value = '-1') |
| 286 |
| 287 def _transform_unicode_iteration(self, node, slice_node, reversed=False): |
| 288 if slice_node.is_literal: |
| 289 # try to reduce to byte iteration for plain Latin-1 strings |
| 290 try: |
| 291 bytes_value = BytesLiteral(slice_node.value.encode('latin1')) |
| 292 except UnicodeEncodeError: |
| 293 pass |
| 294 else: |
| 295 bytes_slice = ExprNodes.SliceIndexNode( |
| 296 slice_node.pos, |
| 297 base=ExprNodes.BytesNode( |
| 298 slice_node.pos, value=bytes_value, |
| 299 constant_result=bytes_value, |
| 300 type=PyrexTypes.c_char_ptr_type).coerce_to( |
| 301 PyrexTypes.c_uchar_ptr_type, self.current_env()), |
| 302 start=None, |
| 303 stop=ExprNodes.IntNode( |
| 304 slice_node.pos, value=str(len(bytes_value)), |
| 305 constant_result=len(bytes_value), |
| 306 type=PyrexTypes.c_py_ssize_t_type), |
| 307 type=Builtin.unicode_type, # hint for Python conversion |
| 308 ) |
| 309 return self._transform_carray_iteration(node, bytes_slice, rever
sed) |
| 310 |
| 311 unpack_temp_node = UtilNodes.LetRefNode( |
| 312 slice_node.as_none_safe_node("'NoneType' is not iterable")) |
| 313 |
| 314 start_node = ExprNodes.IntNode( |
| 315 node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t
_type) |
| 316 length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) |
| 317 end_node = length_temp.ref(node.pos) |
| 318 if reversed: |
| 319 relation1, relation2 = '>', '>=' |
| 320 start_node, end_node = end_node, start_node |
| 321 else: |
| 322 relation1, relation2 = '<=', '<' |
| 323 |
| 324 kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) |
| 325 data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type) |
| 326 counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) |
| 327 |
| 328 target_value = ExprNodes.PythonCapiCallNode( |
| 329 slice_node.pos, "__Pyx_PyUnicode_READ", |
| 330 self.PyUnicode_READ_func_type, |
| 331 args = [kind_temp.ref(slice_node.pos), |
| 332 data_temp.ref(slice_node.pos), |
| 333 counter_temp.ref(node.target.pos)], |
| 334 is_temp = False, |
| 335 ) |
| 336 if target_value.type != node.target.type: |
| 337 target_value = target_value.coerce_to(node.target.type, |
| 338 self.current_env()) |
| 339 target_assign = Nodes.SingleAssignmentNode( |
| 340 pos = node.target.pos, |
| 341 lhs = node.target, |
| 342 rhs = target_value) |
| 343 body = Nodes.StatListNode( |
| 344 node.pos, |
| 345 stats = [target_assign, node.body]) |
| 346 |
| 347 loop_node = Nodes.ForFromStatNode( |
| 348 node.pos, |
| 349 bound1=start_node, relation1=relation1, |
| 350 target=counter_temp.ref(node.target.pos), |
| 351 relation2=relation2, bound2=end_node, |
| 352 step=None, body=body, |
| 353 else_clause=node.else_clause, |
| 354 from_range=True) |
| 355 |
| 356 setup_node = Nodes.ExprStatNode( |
| 357 node.pos, |
| 358 expr = ExprNodes.PythonCapiCallNode( |
| 359 slice_node.pos, "__Pyx_init_unicode_iteration", |
| 360 self.init_unicode_iteration_func_type, |
| 361 args = [unpack_temp_node, |
| 362 ExprNodes.AmpersandNode(slice_node.pos, operand=length_t
emp.ref(slice_node.pos), |
| 363 type=PyrexTypes.c_py_ssize_t_ptr
_type), |
| 364 ExprNodes.AmpersandNode(slice_node.pos, operand=data_tem
p.ref(slice_node.pos), |
| 365 type=PyrexTypes.c_void_ptr_ptr_t
ype), |
| 366 ExprNodes.AmpersandNode(slice_node.pos, operand=kind_tem
p.ref(slice_node.pos), |
| 367 type=PyrexTypes.c_int_ptr_type), |
| 368 ], |
| 369 is_temp = True, |
| 370 result_is_used = False, |
| 371 utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c
"), |
| 372 )) |
| 373 return UtilNodes.LetNode( |
| 374 unpack_temp_node, |
| 375 UtilNodes.TempsBlockNode( |
| 376 node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp
], |
| 377 body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])
)) |
| 378 |
| 379 def _transform_carray_iteration(self, node, slice_node, reversed=False): |
| 380 neg_step = False |
| 381 if isinstance(slice_node, ExprNodes.SliceIndexNode): |
| 382 slice_base = slice_node.base |
| 383 start = filter_none_node(slice_node.start) |
| 384 stop = filter_none_node(slice_node.stop) |
| 385 step = None |
| 386 if not stop: |
| 387 if not slice_base.type.is_pyobject: |
| 388 error(slice_node.pos, "C array iteration requires known end
index") |
| 389 return node |
| 390 |
| 391 elif isinstance(slice_node, ExprNodes.IndexNode): |
| 392 assert isinstance(slice_node.index, ExprNodes.SliceNode) |
| 393 slice_base = slice_node.base |
| 394 index = slice_node.index |
| 395 start = filter_none_node(index.start) |
| 396 stop = filter_none_node(index.stop) |
| 397 step = filter_none_node(index.step) |
| 398 if step: |
| 399 if not isinstance(step.constant_result, (int,long)) \ |
| 400 or step.constant_result == 0 \ |
| 401 or step.constant_result > 0 and not stop \ |
| 402 or step.constant_result < 0 and not start: |
| 403 if not slice_base.type.is_pyobject: |
| 404 error(step.pos, "C array iteration requires known step s
ize and end index") |
| 405 return node |
| 406 else: |
| 407 # step sign is handled internally by ForFromStatNode |
| 408 step_value = step.constant_result |
| 409 if reversed: |
| 410 step_value = -step_value |
| 411 neg_step = step_value < 0 |
| 412 step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssiz
e_t_type, |
| 413 value=str(abs(step_value)), |
| 414 constant_result=abs(step_value)) |
| 415 |
| 416 elif slice_node.type.is_array: |
| 417 if slice_node.type.size is None: |
| 418 error(slice_node.pos, "C array iteration requires known end inde
x") |
| 419 return node |
| 420 slice_base = slice_node |
| 421 start = None |
| 422 stop = ExprNodes.IntNode( |
| 423 slice_node.pos, value=str(slice_node.type.size), |
| 424 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.ty
pe.size) |
| 425 step = None |
| 426 |
| 427 else: |
| 428 if not slice_node.type.is_pyobject: |
| 429 error(slice_node.pos, "C array iteration requires known end inde
x") |
| 430 return node |
| 431 |
| 432 if start: |
| 433 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_e
nv()) |
| 434 if stop: |
| 435 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env
()) |
| 436 if stop is None: |
| 437 if neg_step: |
| 438 stop = ExprNodes.IntNode( |
| 439 slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_typ
e, constant_result=-1) |
| 440 else: |
| 441 error(slice_node.pos, "C array iteration requires known step siz
e and end index") |
| 442 return node |
| 443 |
| 444 if reversed: |
| 445 if not start: |
| 446 start = ExprNodes.IntNode(slice_node.pos, value="0", constant_r
esult=0, |
| 447 type=PyrexTypes.c_py_ssize_t_type) |
| 448 # if step was provided, it was already negated above |
| 449 start, stop = stop, start |
| 450 |
| 451 ptr_type = slice_base.type |
| 452 if ptr_type.is_array: |
| 453 ptr_type = ptr_type.element_ptr_type() |
| 454 carray_ptr = slice_base.coerce_to_simple(self.current_env()) |
| 455 |
| 456 if start and start.constant_result != 0: |
| 457 start_ptr_node = ExprNodes.AddNode( |
| 458 start.pos, |
| 459 operand1=carray_ptr, |
| 460 operator='+', |
| 461 operand2=start, |
| 462 type=ptr_type) |
| 463 else: |
| 464 start_ptr_node = carray_ptr |
| 465 |
| 466 if stop and stop.constant_result != 0: |
| 467 stop_ptr_node = ExprNodes.AddNode( |
| 468 stop.pos, |
| 469 operand1=ExprNodes.CloneNode(carray_ptr), |
| 470 operator='+', |
| 471 operand2=stop, |
| 472 type=ptr_type |
| 473 ).coerce_to_simple(self.current_env()) |
| 474 else: |
| 475 stop_ptr_node = ExprNodes.CloneNode(carray_ptr) |
| 476 |
| 477 counter = UtilNodes.TempHandle(ptr_type) |
| 478 counter_temp = counter.ref(node.target.pos) |
| 479 |
| 480 if slice_base.type.is_string and node.target.type.is_pyobject: |
| 481 # special case: char* -> bytes/unicode |
| 482 if slice_node.type is Builtin.unicode_type: |
| 483 target_value = ExprNodes.CastNode( |
| 484 ExprNodes.DereferenceNode( |
| 485 node.target.pos, operand=counter_temp, |
| 486 type=ptr_type.base_type), |
| 487 PyrexTypes.c_py_ucs4_type).coerce_to( |
| 488 node.target.type, self.current_env()) |
| 489 else: |
| 490 # char* -> bytes coercion requires slicing, not indexing |
| 491 target_value = ExprNodes.SliceIndexNode( |
| 492 node.target.pos, |
| 493 start=ExprNodes.IntNode(node.target.pos, value='0', |
| 494 constant_result=0, |
| 495 type=PyrexTypes.c_int_type), |
| 496 stop=ExprNodes.IntNode(node.target.pos, value='1', |
| 497 constant_result=1, |
| 498 type=PyrexTypes.c_int_type), |
| 499 base=counter_temp, |
| 500 type=Builtin.bytes_type, |
| 501 is_temp=1) |
| 502 elif node.target.type.is_ptr and not node.target.type.assignable_from(pt
r_type.base_type): |
| 503 # Allow iteration with pointer target to avoid copy. |
| 504 target_value = counter_temp |
| 505 else: |
| 506 # TODO: can this safely be replaced with DereferenceNode() as above? |
| 507 target_value = ExprNodes.IndexNode( |
| 508 node.target.pos, |
| 509 index=ExprNodes.IntNode(node.target.pos, value='0', |
| 510 constant_result=0, |
| 511 type=PyrexTypes.c_int_type), |
| 512 base=counter_temp, |
| 513 is_buffer_access=False, |
| 514 type=ptr_type.base_type) |
| 515 |
| 516 if target_value.type != node.target.type: |
| 517 target_value = target_value.coerce_to(node.target.type, |
| 518 self.current_env()) |
| 519 |
| 520 target_assign = Nodes.SingleAssignmentNode( |
| 521 pos = node.target.pos, |
| 522 lhs = node.target, |
| 523 rhs = target_value) |
| 524 |
| 525 body = Nodes.StatListNode( |
| 526 node.pos, |
| 527 stats = [target_assign, node.body]) |
| 528 |
| 529 relation1, relation2 = self._find_for_from_node_relations(neg_step, reve
rsed) |
| 530 |
| 531 for_node = Nodes.ForFromStatNode( |
| 532 node.pos, |
| 533 bound1=start_ptr_node, relation1=relation1, |
| 534 target=counter_temp, |
| 535 relation2=relation2, bound2=stop_ptr_node, |
| 536 step=step, body=body, |
| 537 else_clause=node.else_clause, |
| 538 from_range=True) |
| 539 |
| 540 return UtilNodes.TempsBlockNode( |
| 541 node.pos, temps=[counter], |
| 542 body=for_node) |
| 543 |
| 544 def _transform_enumerate_iteration(self, node, enumerate_function): |
| 545 args = enumerate_function.arg_tuple.args |
| 546 if len(args) == 0: |
| 547 error(enumerate_function.pos, |
| 548 "enumerate() requires an iterable argument") |
| 549 return node |
| 550 elif len(args) > 2: |
| 551 error(enumerate_function.pos, |
| 552 "enumerate() takes at most 2 arguments") |
| 553 return node |
| 554 |
| 555 if not node.target.is_sequence_constructor: |
| 556 # leave this untouched for now |
| 557 return node |
| 558 targets = node.target.args |
| 559 if len(targets) != 2: |
| 560 # leave this untouched for now |
| 561 return node |
| 562 |
| 563 enumerate_target, iterable_target = targets |
| 564 counter_type = enumerate_target.type |
| 565 |
| 566 if not counter_type.is_pyobject and not counter_type.is_int: |
| 567 # nothing we can do here, I guess |
| 568 return node |
| 569 |
| 570 if len(args) == 2: |
| 571 start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.cu
rrent_env()) |
| 572 else: |
| 573 start = ExprNodes.IntNode(enumerate_function.pos, |
| 574 value='0', |
| 575 type=counter_type, |
| 576 constant_result=0) |
| 577 temp = UtilNodes.LetRefNode(start) |
| 578 |
| 579 inc_expression = ExprNodes.AddNode( |
| 580 enumerate_function.pos, |
| 581 operand1 = temp, |
| 582 operand2 = ExprNodes.IntNode(node.pos, value='1', |
| 583 type=counter_type, |
| 584 constant_result=1), |
| 585 operator = '+', |
| 586 type = counter_type, |
| 587 #inplace = True, # not worth using in-place operation for Py ints |
| 588 is_temp = counter_type.is_pyobject |
| 589 ) |
| 590 |
| 591 loop_body = [ |
| 592 Nodes.SingleAssignmentNode( |
| 593 pos = enumerate_target.pos, |
| 594 lhs = enumerate_target, |
| 595 rhs = temp), |
| 596 Nodes.SingleAssignmentNode( |
| 597 pos = enumerate_target.pos, |
| 598 lhs = temp, |
| 599 rhs = inc_expression) |
| 600 ] |
| 601 |
| 602 if isinstance(node.body, Nodes.StatListNode): |
| 603 node.body.stats = loop_body + node.body.stats |
| 604 else: |
| 605 loop_body.append(node.body) |
| 606 node.body = Nodes.StatListNode( |
| 607 node.body.pos, |
| 608 stats = loop_body) |
| 609 |
| 610 node.target = iterable_target |
| 611 node.item = node.item.coerce_to(iterable_target.type, self.current_env()
) |
| 612 node.iterator.sequence = args[0] |
| 613 |
| 614 # recurse into loop to check for further optimisations |
| 615 return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterat
or.sequence)) |
| 616 |
| 617 def _find_for_from_node_relations(self, neg_step_value, reversed): |
| 618 if reversed: |
| 619 if neg_step_value: |
| 620 return '<', '<=' |
| 621 else: |
| 622 return '>', '>=' |
| 623 else: |
| 624 if neg_step_value: |
| 625 return '>=', '>' |
| 626 else: |
| 627 return '<=', '<' |
| 628 |
| 629 def _transform_range_iteration(self, node, range_function, reversed=False): |
| 630 args = range_function.arg_tuple.args |
| 631 if len(args) < 3: |
| 632 step_pos = range_function.pos |
| 633 step_value = 1 |
| 634 step = ExprNodes.IntNode(step_pos, value='1', |
| 635 constant_result=1) |
| 636 else: |
| 637 step = args[2] |
| 638 step_pos = step.pos |
| 639 if not isinstance(step.constant_result, (int, long)): |
| 640 # cannot determine step direction |
| 641 return node |
| 642 step_value = step.constant_result |
| 643 if step_value == 0: |
| 644 # will lead to an error elsewhere |
| 645 return node |
| 646 if reversed and step_value not in (1, -1): |
| 647 # FIXME: currently broken - requires calculation of the correct
bounds |
| 648 return node |
| 649 if not isinstance(step, ExprNodes.IntNode): |
| 650 step = ExprNodes.IntNode(step_pos, value=str(step_value), |
| 651 constant_result=step_value) |
| 652 |
| 653 if len(args) == 1: |
| 654 bound1 = ExprNodes.IntNode(range_function.pos, value='0', |
| 655 constant_result=0) |
| 656 bound2 = args[0].coerce_to_integer(self.current_env()) |
| 657 else: |
| 658 bound1 = args[0].coerce_to_integer(self.current_env()) |
| 659 bound2 = args[1].coerce_to_integer(self.current_env()) |
| 660 |
| 661 relation1, relation2 = self._find_for_from_node_relations(step_value < 0
, reversed) |
| 662 |
| 663 if reversed: |
| 664 bound1, bound2 = bound2, bound1 |
| 665 if step_value < 0: |
| 666 step_value = -step_value |
| 667 else: |
| 668 if step_value < 0: |
| 669 step_value = -step_value |
| 670 |
| 671 step.value = str(step_value) |
| 672 step.constant_result = step_value |
| 673 step = step.coerce_to_integer(self.current_env()) |
| 674 |
| 675 if not bound2.is_literal: |
| 676 # stop bound must be immutable => keep it in a temp var |
| 677 bound2_is_temp = True |
| 678 bound2 = UtilNodes.LetRefNode(bound2) |
| 679 else: |
| 680 bound2_is_temp = False |
| 681 |
| 682 for_node = Nodes.ForFromStatNode( |
| 683 node.pos, |
| 684 target=node.target, |
| 685 bound1=bound1, relation1=relation1, |
| 686 relation2=relation2, bound2=bound2, |
| 687 step=step, body=node.body, |
| 688 else_clause=node.else_clause, |
| 689 from_range=True) |
| 690 |
| 691 if bound2_is_temp: |
| 692 for_node = UtilNodes.LetNode(bound2, for_node) |
| 693 |
| 694 return for_node |
| 695 |
| 696 def _transform_dict_iteration(self, node, dict_obj, method, keys, values): |
| 697 temps = [] |
| 698 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type) |
| 699 temps.append(temp) |
| 700 dict_temp = temp.ref(dict_obj.pos) |
| 701 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) |
| 702 temps.append(temp) |
| 703 pos_temp = temp.ref(node.pos) |
| 704 |
| 705 key_target = value_target = tuple_target = None |
| 706 if keys and values: |
| 707 if node.target.is_sequence_constructor: |
| 708 if len(node.target.args) == 2: |
| 709 key_target, value_target = node.target.args |
| 710 else: |
| 711 # unusual case that may or may not lead to an error |
| 712 return node |
| 713 else: |
| 714 tuple_target = node.target |
| 715 elif keys: |
| 716 key_target = node.target |
| 717 else: |
| 718 value_target = node.target |
| 719 |
| 720 if isinstance(node.body, Nodes.StatListNode): |
| 721 body = node.body |
| 722 else: |
| 723 body = Nodes.StatListNode(pos = node.body.pos, |
| 724 stats = [node.body]) |
| 725 |
| 726 # keep original length to guard against dict modification |
| 727 dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) |
| 728 temps.append(dict_len_temp) |
| 729 dict_len_temp_addr = ExprNodes.AmpersandNode( |
| 730 node.pos, operand=dict_len_temp.ref(dict_obj.pos), |
| 731 type=PyrexTypes.c_ptr_type(dict_len_temp.type)) |
| 732 temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) |
| 733 temps.append(temp) |
| 734 is_dict_temp = temp.ref(node.pos) |
| 735 is_dict_temp_addr = ExprNodes.AmpersandNode( |
| 736 node.pos, operand=is_dict_temp, |
| 737 type=PyrexTypes.c_ptr_type(temp.type)) |
| 738 |
| 739 iter_next_node = Nodes.DictIterationNextNode( |
| 740 dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp, |
| 741 key_target, value_target, tuple_target, |
| 742 is_dict_temp) |
| 743 iter_next_node = iter_next_node.analyse_expressions(self.current_env()) |
| 744 body.stats[0:0] = [iter_next_node] |
| 745 |
| 746 if method: |
| 747 method_node = ExprNodes.StringNode( |
| 748 dict_obj.pos, is_identifier=True, value=method) |
| 749 dict_obj = dict_obj.as_none_safe_node( |
| 750 "'NoneType' object has no attribute '%s'", |
| 751 error = "PyExc_AttributeError", |
| 752 format_args = [method]) |
| 753 else: |
| 754 method_node = ExprNodes.NullNode(dict_obj.pos) |
| 755 dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iter
able") |
| 756 |
| 757 def flag_node(value): |
| 758 value = value and 1 or 0 |
| 759 return ExprNodes.IntNode(node.pos, value=str(value), constant_result
=value) |
| 760 |
| 761 result_code = [ |
| 762 Nodes.SingleAssignmentNode( |
| 763 node.pos, |
| 764 lhs = pos_temp, |
| 765 rhs = ExprNodes.IntNode(node.pos, value='0', |
| 766 constant_result=0)), |
| 767 Nodes.SingleAssignmentNode( |
| 768 dict_obj.pos, |
| 769 lhs = dict_temp, |
| 770 rhs = ExprNodes.PythonCapiCallNode( |
| 771 dict_obj.pos, |
| 772 "__Pyx_dict_iterator", |
| 773 self.PyDict_Iterator_func_type, |
| 774 utility_code = UtilityCode.load_cached("dict_iter", "Optimiz
e.c"), |
| 775 args = [dict_obj, flag_node(dict_obj.type is Builtin.dict_ty
pe), |
| 776 method_node, dict_len_temp_addr, is_dict_temp_addr, |
| 777 ], |
| 778 is_temp=True, |
| 779 )), |
| 780 Nodes.WhileStatNode( |
| 781 node.pos, |
| 782 condition = None, |
| 783 body = body, |
| 784 else_clause = node.else_clause |
| 785 ) |
| 786 ] |
| 787 |
| 788 return UtilNodes.TempsBlockNode( |
| 789 node.pos, temps=temps, |
| 790 body=Nodes.StatListNode( |
| 791 node.pos, |
| 792 stats = result_code |
| 793 )) |
| 794 |
| 795 PyDict_Iterator_func_type = PyrexTypes.CFuncType( |
| 796 PyrexTypes.py_object_type, [ |
| 797 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), |
| 798 PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None), |
| 799 PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, N
one), |
| 800 PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_pt
r_type, None), |
| 801 PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, Non
e), |
| 802 ]) |
| 803 |
| 804 |
| 805 class SwitchTransform(Visitor.VisitorTransform): |
| 806 """ |
| 807 This transformation tries to turn long if statements into C switch statement
s. |
| 808 The requirement is that every clause be an (or of) var == value, where the v
ar |
| 809 is common among all clauses and both var and value are ints. |
| 810 """ |
| 811 NO_MATCH = (None, None, None) |
| 812 |
def extract_conditions(self, cond, allow_not_in):
    """Try to reduce a condition to "target compared against values".

    Returns a triple ``(not_in, target, values)``: ``not_in`` is true for
    a negated test ('!=' or 'not_in'), ``target`` is the operand node
    being tested and ``values`` is a list of value nodes.  Returns
    ``self.NO_MATCH`` when the condition does not have that shape.

    ``allow_not_in`` permits negated comparisons and 'and' combinations.
    """
    # strip wrapper nodes until the actual comparison is reached
    while True:
        if isinstance(cond, (ExprNodes.CoerceToTempNode,
                             ExprNodes.CoerceToBooleanNode)):
            cond = cond.arg
        elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
            # this is what we get from the FlattenInListTransform
            cond = cond.subexpression
        elif isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand
        else:
            break

    if isinstance(cond, ExprNodes.PrimaryCmpNode):
        if cond.cascade is not None:
            # cascaded comparisons (a == b == c) are not handled
            return self.NO_MATCH
        elif cond.is_c_string_contains() and \
               isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
            # "char in 'literal'": one value per distinct character
            not_in = cond.operator == 'not_in'
            if not_in and not allow_not_in:
                return self.NO_MATCH
            if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
                   cond.operand2.contains_surrogates():
                # dealing with surrogates leads to different
                # behaviour on wide and narrow Unicode
                # platforms => refuse to optimise this case
                return self.NO_MATCH
            return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
        elif not cond.is_python_comparison():
            if cond.operator == '==':
                not_in = False
            elif allow_not_in and cond.operator == '!=':
                not_in = True
            else:
                return self.NO_MATCH
            # this looks somewhat silly, but it does the right
            # checks for NameNode and AttributeNode
            if is_common_value(cond.operand1, cond.operand1):
                if cond.operand2.is_literal:
                    return not_in, cond.operand1, [cond.operand2]
                elif getattr(cond.operand2, 'entry', None) \
                         and cond.operand2.entry.is_const:
                    return not_in, cond.operand1, [cond.operand2]
            # same check with the operands swapped ("value == target")
            if is_common_value(cond.operand2, cond.operand2):
                if cond.operand1.is_literal:
                    return not_in, cond.operand2, [cond.operand1]
                elif getattr(cond.operand1, 'entry', None) \
                         and cond.operand1.entry.is_const:
                    return not_in, cond.operand2, [cond.operand1]
    elif isinstance(cond, ExprNodes.BoolBinopNode):
        if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
            # negated tests may only be joined by 'and', positive ones by 'or'
            allow_not_in = (cond.operator == 'and')
            not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
            not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
            # both sides must test the same target with the same polarity
            if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                if (not not_in_1) or allow_not_in:
                    return not_in_1, t1, c1+c2
    return self.NO_MATCH
| 871 |
def extract_in_string_conditions(self, string_literal):
    """Expand a string literal into one node per distinct character.

    Unicode literals give IntNodes holding the character ordinals,
    other literals give CharNodes holding 1-character strings.  The
    resulting nodes are ordered by value.
    """
    pos = string_literal.pos
    text = string_literal.value
    if isinstance(string_literal, ExprNodes.UnicodeNode):
        ordinals = sorted([ord(ch) for ch in set(text)])
        return [ExprNodes.IntNode(pos, value=str(ordinal),
                                  constant_result=ordinal)
                for ordinal in ordinals]

    # Slice instead of iterating: Py3's bytes type yields integers on
    # iteration, whereas Py2 yields 1-char byte strings.
    distinct = set()
    for index in range(len(text)):
        distinct.add(text[index:index+1])
    return [ExprNodes.CharNode(pos, value=char, constant_result=char)
            for char in sorted(distinct)]
| 889 |
def extract_common_conditions(self, common_var, condition, allow_not_in):
    """Extract switch conditions and validate them against ``common_var``.

    Returns ``(not_in, var, conditions)`` as found by
    ``extract_conditions()``, or ``self.NO_MATCH`` if nothing was found,
    if the tested variable differs from ``common_var`` (when given), or
    if the variable or any value is neither of int nor of enum type.
    """
    not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
    if var is None:
        return self.NO_MATCH
    if common_var is not None and not is_common_value(var, common_var):
        return self.NO_MATCH

    var_type = var.type
    if not (var_type.is_int or var_type.is_enum):
        return self.NO_MATCH
    unsuitable = [value for value in conditions
                  if not (value.type.is_int or value.type.is_enum)]
    if unsuitable:
        return self.NO_MATCH
    return not_in, var, conditions
| 899 |
def has_duplicate_values(self, condition_values):
    """Tell whether two condition values may denote the same case label.

    Values with a constant result are compared by that result, all
    others by the cname of their entry.  A value that has neither is
    reported as a duplicate to stay on the safe side.
    """
    # duplicated values don't work in a switch statement
    seen = set()
    for value in condition_values:
        if value.has_constant_result():
            key = value.constant_result
        else:
            # this isn't completely safe as we don't know the
            # final C value, but this is about the best we can do
            try:
                key = value.entry.cname
            except AttributeError:
                return True # play safe
        if key in seen:
            return True
        seen.add(key)
    return False
| 918 |
def visit_IfStatNode(self, node):
    """Replace an if-statement by a SwitchStatNode where possible.

    Every clause must test the same variable; at least two values are
    required overall and none may be duplicated.  Otherwise the children
    are visited and the node is returned unchanged.
    """
    switch_var = None
    switch_cases = []
    for clause in node.if_clauses:
        _, switch_var, clause_values = self.extract_common_conditions(
            switch_var, clause.condition, False)
        if switch_var is None:
            self.visitchildren(node)
            return node
        switch_cases.append(Nodes.SwitchCaseNode(
            pos = clause.pos,
            conditions = clause_values,
            body = clause.body))

    all_values = []
    for case in switch_cases:
        all_values.extend(case.conditions)

    if len(all_values) < 2 or self.has_duplicate_values(all_values):
        self.visitchildren(node)
        return node

    return Nodes.SwitchStatNode(
        pos = node.pos,
        test = unwrap_node(switch_var),
        cases = switch_cases,
        else_clause = node.else_clause)
| 947 |
def visit_CondExprNode(self, node):
    """Turn "a if <switchable test> else b" into a switch statement.

    Falls back to visiting the children when the test cannot be
    expressed as at least two distinct switch values on one variable.
    """
    not_in, common_var, conditions = self.extract_common_conditions(
        None, node.test, True)
    switchable = (common_var is not None
                  and len(conditions) >= 2
                  and not self.has_duplicate_values(conditions))
    if not switchable:
        self.visitchildren(node)
        return node
    return self.build_simple_switch_statement(
        node, common_var, conditions, not_in,
        node.true_val, node.false_val)
| 959 |
def visit_BoolBinopNode(self, node):
    """Turn a switchable and/or expression into a switch statement.

    The switch assigns a True or False BoolNode as the result.  If the
    expression is not switchable, the children are visited instead.
    """
    not_in, common_var, conditions = self.extract_common_conditions(
        None, node, True)
    if common_var is not None and len(conditions) >= 2:
        if not self.has_duplicate_values(conditions):
            true_node = ExprNodes.BoolNode(
                node.pos, value=True, constant_result=True)
            false_node = ExprNodes.BoolNode(
                node.pos, value=False, constant_result=False)
            return self.build_simple_switch_statement(
                node, common_var, conditions, not_in,
                true_node, false_node)

    self.visitchildren(node)
    return node
| 973 |
def visit_PrimaryCmpNode(self, node):
    """Turn a switchable comparison into a switch statement.

    The switch assigns a True or False BoolNode as the result.  If the
    comparison does not yield at least two distinct switch values, the
    children are visited and the node is kept.
    """
    not_in, common_var, conditions = self.extract_common_conditions(
        None, node, True)

    if common_var is None:
        can_switch = False
    elif len(conditions) < 2:
        can_switch = False
    else:
        can_switch = not self.has_duplicate_values(conditions)

    if not can_switch:
        self.visitchildren(node)
        return node

    pos = node.pos
    return self.build_simple_switch_statement(
        node, common_var, conditions, not_in,
        ExprNodes.BoolNode(pos, value=True, constant_result=True),
        ExprNodes.BoolNode(pos, value=False, constant_result=False))
| 987 |
def build_simple_switch_statement(self, node, common_var, conditions,
                                  not_in, true_val, false_val):
    """Build a single-case switch that computes the value of ``node``.

    The case for ``conditions`` assigns ``true_val`` to a result
    reference and the else clause assigns ``false_val``; the two are
    swapped when ``not_in`` is set.  The switch is returned wrapped in a
    TempResultFromStatNode.
    """
    pos = node.pos
    result_ref = UtilNodes.ResultRefNode(node)
    on_match = Nodes.SingleAssignmentNode(
        pos, lhs = result_ref, rhs = true_val, first = True)
    on_default = Nodes.SingleAssignmentNode(
        pos, lhs = result_ref, rhs = false_val, first = True)

    if not_in:
        # a negated test yields the "false" value on a match
        on_match, on_default = on_default, on_match

    only_case = Nodes.SwitchCaseNode(
        pos = pos, conditions = conditions, body = on_match)
    switch_stat = Nodes.SwitchStatNode(
        pos = pos,
        test = unwrap_node(common_var),
        cases = [only_case],
        else_clause = on_default)
    return UtilNodes.TempResultFromStatNode(result_ref, switch_stat)
| 1016 |
def visit_EvalWithTempExprNode(self, node):
    """Drop the temp wrapper if a rewritten subexpression no longer uses it.

    Returns the bare subexpression when visiting the children replaced
    it and the new tree does not contain the temp reference; returns the
    node itself otherwise.
    """
    # drop unused expression temp from FlattenInListTransform
    before = node.subexpression
    temp_ref = node.lazy_temp
    self.visitchildren(node)
    after = node.subexpression
    if after is before:
        return node
    # node was restructured => check if temp is still used
    if Visitor.tree_contains(after, temp_ref):
        return node
    return after
| 1027 |
| 1028 visit_Node = Visitor.VisitorTransform.recurse_to_children |
| 1029 |
| 1030 |
| 1031 class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations): |
| 1032 """ |
| 1033 This transformation flattens "x in [val1, ..., valn]" into a sequential list |
| 1034 of comparisons. |
| 1035 """ |
| 1036 |
def visit_PrimaryCmpNode(self, node):
    """Expand "x in (a, b, ...)" / "x not in (...)" into a chain of tests.

    'in' becomes '==' tests joined by 'or', 'not_in' becomes '!=' tests
    joined by 'and'.  Only non-cascaded comparisons against a non-empty
    tuple, list or set literal are rewritten.
    """
    self.visitchildren(node)
    if node.cascade is not None:
        return node
    if node.operator == 'in':
        joiner, comparison = 'or', '=='
    elif node.operator == 'not_in':
        joiner, comparison = 'and', '!='
    else:
        return node

    container = node.operand2
    if not isinstance(container, (ExprNodes.TupleNode,
                                  ExprNodes.ListNode,
                                  ExprNodes.SetNode)):
        return node

    items = container.args
    if len(items) == 0:
        # note: lhs may have side effects
        return node

    pos = node.pos
    lhs = UtilNodes.ResultRefNode(node.operand1)

    tests = []
    let_refs = []
    for item in items:
        try:
            # Trial optimisation to avoid redundant temp
            # assignments.  However, since is_simple() is meant to
            # be called after type analysis, we ignore any errors
            # and just play safe in that case.
            simple = item.is_simple()
        except Exception:
            simple = False
        if not simple:
            # must evaluate all non-simple RHS before doing the comparisons
            item = UtilNodes.LetRefNode(item)
            let_refs.append(item)
        comparison_node = ExprNodes.PrimaryCmpNode(
            pos = pos,
            operand1 = lhs,
            operator = comparison,
            operand2 = item,
            cascade = None)
        tests.append(ExprNodes.TypecastNode(
            pos = pos,
            operand = comparison_node,
            type = PyrexTypes.c_bint_type))

    # left fold: ((t0 <op> t1) <op> t2) ...
    condition = tests[0]
    for test in tests[1:]:
        condition = ExprNodes.BoolBinopNode(
            pos = pos,
            operator = joiner,
            operand1 = condition,
            operand2 = test)

    result = UtilNodes.EvalWithTempExprNode(lhs, condition)
    let_refs.reverse()
    for let_ref in let_refs:
        result = UtilNodes.EvalWithTempExprNode(let_ref, result)
    return result
| 1099 |
| 1100 visit_Node = Visitor.VisitorTransform.recurse_to_children |
| 1101 |
| 1102 |
| 1103 class DropRefcountingTransform(Visitor.VisitorTransform): |
| 1104 """Drop ref-counting in safe places. |
| 1105 """ |
| 1106 visit_Node = Visitor.VisitorTransform.recurse_to_children |
| 1107 |
def visit_ParallelAssignmentNode(self, node):
    """
    Parallel swap assignments like 'a,b = b,a' are safe.

    Switches off 'use_managed_ref' on the operands of such an
    assignment when the assigned names are a permutation of the names
    that are read.  The node itself is always returned, modified in
    place or not.
    """
    left_names, right_names = [], []
    left_indices, right_indices = [], []
    temps = []

    # collect the operands of all assignments; bail out on anything
    # that _extract_operand() does not accept
    for stat in node.stats:
        if isinstance(stat, Nodes.SingleAssignmentNode):
            if not self._extract_operand(stat.lhs, left_names,
                                         left_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, right_names,
                                         right_indices, temps):
                return node
        elif isinstance(stat, Nodes.CascadedAssignmentNode):
            # FIXME
            return node
        else:
            return node

    if left_names or right_names:
        # lhs/rhs names must be a non-redundant permutation
        lnames = [ path for path, n in left_names ]
        rnames = [ path for path, n in right_names ]
        if set(lnames) != set(rnames):
            return node
        if len(set(lnames)) != len(right_names):
            return node

    if left_indices or right_indices:
        # base name and index of index nodes must be a
        # non-redundant permutation
        lindices = []
        for lhs_node in left_indices:
            index_id = self._extract_index_id(lhs_node)
            if not index_id:
                return node
            lindices.append(index_id)
        rindices = []
        for rhs_node in right_indices:
            index_id = self._extract_index_id(rhs_node)
            if not index_id:
                return node
            rindices.append(index_id)

        if set(lindices) != set(rindices):
            return node
        if len(set(lindices)) != len(right_indices):
            return node

        # really supporting IndexNode requires support in
        # __Pyx_GetItemInt(), so let's stop short for now
        return node

    # the nodes wrapped by the temps keep their own setting
    temp_args = [t.arg for t in temps]
    for temp in temps:
        temp.use_managed_ref = False

    for _, name_node in left_names + right_names:
        if name_node not in temp_args:
            name_node.use_managed_ref = False

    for index_node in left_indices + right_indices:
        index_node.use_managed_ref = False

    return node
| 1176 |
def _extract_operand(self, node, names, indices, temps):
    """Classify an assignment operand and record it in the given lists.

    A CoerceToTempNode is appended to ``temps`` and then unwrapped.
    Names and non-Python attribute chains are appended to ``names`` as
    ``(dotted_path, node)``; integer indexing into a list-typed name is
    appended to ``indices``.  Returns False for any other operand and
    for operands that are not Python objects.
    """
    operand = unwrap_node(node)
    if not operand.type.is_pyobject:
        return False
    if isinstance(operand, ExprNodes.CoerceToTempNode):
        temps.append(operand)
        operand = operand.arg

    # walk down the attribute chain, collecting the path innermost-last
    path = []
    base = operand
    while isinstance(base, ExprNodes.AttributeNode):
        if base.is_py_attr:
            return False
        path.append(base.member)
        base = base.obj

    if isinstance(base, ExprNodes.NameNode):
        path.append(base.name)
        path.reverse()
        names.append( ('.'.join(path), operand) )
        return True

    if not isinstance(operand, ExprNodes.IndexNode):
        return False
    if (operand.base.type != Builtin.list_type
            or not operand.index.type.is_int
            or not isinstance(operand.base, ExprNodes.NameNode)):
        return False
    indices.append(operand)
    return True
| 1205 |
def _extract_index_id(self, index_node):
    """Return ``(base name, index name)`` for an index node.

    Only indices that are plain names are supported; None is returned
    for everything else, constants included.
    """
    index = index_node.index
    if not isinstance(index, ExprNodes.NameNode):
        # FIXME: constant indices (ExprNodes.ConstNode) are not
        # supported yet either
        return None
    return (index_node.base.name, index.name)
| 1217 |
| 1218 |
| 1219 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): |
| 1220 """Optimize some common calls to builtin types *before* the type |
| 1221 analysis phase and *after* the declarations analysis phase. |
| 1222 |
| 1223 This transform cannot make use of any argument types, but it can |
| 1224 restructure the tree in a way that the type analysis phase can |
| 1225 respond to. |
| 1226 |
| 1227 Introducing C function calls here may not be a good idea. Move |
| 1228 them to the OptimizeBuiltinCalls transform instead, which runs |
| 1229 after type analysis. |
| 1230 """ |
| 1231 # only intercept on call nodes |
| 1232 visit_Node = Visitor.VisitorTransform.recurse_to_children |
| 1233 |
def visit_SimpleCallNode(self, node):
    """Dispatch positional-only calls of builtin names to their handler."""
    self.visitchildren(node)
    callee = node.function
    if self._function_is_builtin_name(callee):
        return self._dispatch_to_handler(node, callee, node.args)
    return node
| 1240 |
def visit_GeneralCallNode(self, node):
    """Dispatch keyword calls of builtin names to their handler.

    Only calls whose positional arguments are a literal TupleNode are
    dispatched; any other call is returned unchanged.
    """
    self.visitchildren(node)
    callee = node.function
    if self._function_is_builtin_name(callee):
        positional = node.positional_args
        if isinstance(positional, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, callee, positional.args, node.keyword_args)
    return node
| 1252 |
| 1253 def _function_is_builtin_name(self, function): |
| 1254 if not function.is_name: |
| 1255 return False |
| 1256 env = self.current_env() |
| 1257 entry = env.lookup(function.name) |
| 1258 if entry is not env.builtin_scope().lookup_here(function.name): |
| 1259 return False |
| 1260 # if entry is None, it's at least an undeclared name, so likely builtin |
| 1261 return True |
| 1262 |
| 1263 def _dispatch_to_handler(self, node, function, args, kwargs=None): |
| 1264 if kwargs is None: |
| 1265 handler_name = '_handle_simple_function_%s' % function.name |
| 1266 else: |
| 1267 handler_name = '_handle_general_function_%s' % function.name |
| 1268 handle_call = getattr(self, handler_name, None) |
| 1269 if handle_call is not None: |
| 1270 if kwargs is None: |
| 1271 return handle_call(node, args) |
| 1272 else: |
| 1273 return handle_call(node, args, kwargs) |
| 1274 return node |
| 1275 |
def _inject_capi_function(self, node, cname, func_type, utility_code=None):
    """Replace ``node.function`` by a PythonCapiFunctionNode for ``cname``.

    The new function node keeps the position and name of the old one.
    """
    replaced = node.function
    capi_function = ExprNodes.PythonCapiFunctionNode(
        replaced.pos, replaced.name, cname, func_type,
        utility_code = utility_code)
    node.function = capi_function
| 1280 |
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call with the wrong number of arguments at ``node.pos``.

    ``expected`` may be a number or a string describing the expected
    argument count; it is omitted from the message when None.
    """
    # how to render the argument list in the message: f(), f(x), f(...)
    arg_str = ''
    if expected: # not None and not 0
        if isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'

    expected_str = ''
    if expected is not None:
        expected_str = 'expected %s, ' % expected

    message = "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args))
    error(node.pos, message)
| 1296 |
| 1297 # specific handlers for simple call nodes |
| 1298 |
def _handle_simple_function_float(self, node, pos_args):
    """Simplify float() calls.

    ``float()`` becomes the literal 0.0 and ``float(x)`` becomes ``x``
    when ``x`` is already typed as C double or Python float.  More than
    one argument is reported as an error.
    """
    if not pos_args:
        return ExprNodes.FloatNode(node.pos, value='0.0')
    if len(pos_args) > 1:
        self._error_wrong_arg_count('float', node, pos_args, 1)
    first_arg = pos_args[0]
    # the argument may not have a type yet at this stage
    known_type = getattr(first_arg, 'type', None)
    if known_type in (PyrexTypes.c_double_type, Builtin.float_type):
        return first_arg
    return node
| 1308 |
class YieldNodeCollector(Visitor.TreeVisitor):
    """Tree visitor for collecting yield expressions.

    'yield_nodes' lists the yield expression nodes that were found and
    'yield_stat_nodes' maps a yield expression to the expression
    statement node that holds it.
    """
    def __init__(self):
        Visitor.TreeVisitor.__init__(self)
        self.yield_stat_nodes = {}
        self.yield_nodes = []

    visit_Node = Visitor.TreeVisitor.visitchildren
    # XXX: disable inlining while it's not back supported
    # NOTE(review): inside the class body, the double underscore prefix
    # mangles the names below to '_YieldNodeCollector__visit_...'.
    # Presumably the visitor dispatch then no longer finds them, which
    # leaves both collections empty -- confirm against Visitor.TreeVisitor.
    def __visit_YieldExprNode(self, node):
        self.yield_nodes.append(node)
        self.visitchildren(node)

    def __visit_ExprStatNode(self, node):
        self.visitchildren(node)
        # only statements that directly wrap a collected yield are recorded
        if node.expr in self.yield_nodes:
            self.yield_stat_nodes[node.expr] = node

    def __visit_GeneratorExpressionNode(self, node):
        # enable when we support generic generator expressions
        #
        # everything below this node is out of scope
        pass
| 1331 |
| 1332 def _find_single_yield_expression(self, node): |
| 1333 collector = self.YieldNodeCollector() |
| 1334 collector.visitchildren(node) |
| 1335 if len(collector.yield_nodes) != 1: |
| 1336 return None, None |
| 1337 yield_node = collector.yield_nodes[0] |
| 1338 try: |
| 1339 return (yield_node.arg, collector.yield_stat_nodes[yield_node]) |
| 1340 except KeyError: |
| 1341 return None, None |
| 1342 |
| 1343 def _handle_simple_function_all(self, node, pos_args): |
| 1344 """Transform |
| 1345 |
| 1346 _result = all(x for L in LL for x in L) |
| 1347 |
| 1348 into |
| 1349 |
| 1350 for L in LL: |
| 1351 for x in L: |
| 1352 if not x: |
| 1353 _result = False |
| 1354 break |
| 1355 else: |
| 1356 continue |
| 1357 break |
| 1358 else: |
| 1359 _result = True |
| 1360 """ |
| 1361 return self._transform_any_all(node, pos_args, False) |
| 1362 |
| 1363 def _handle_simple_function_any(self, node, pos_args): |
| 1364 """Transform |
| 1365 |
| 1366 _result = any(x for L in LL for x in L) |
| 1367 |
| 1368 into |
| 1369 |
| 1370 for L in LL: |
| 1371 for x in L: |
| 1372 if x: |
| 1373 _result = True |
| 1374 break |
| 1375 else: |
| 1376 continue |
| 1377 break |
| 1378 else: |
| 1379 _result = False |
| 1380 """ |
| 1381 return self._transform_any_all(node, pos_args, True) |
| 1382 |
def _transform_any_all(self, node, pos_args, is_any):
    """Rewrite any(genexpr) / all(genexpr) into an inlined loop.

    Applies only to a call with a single generator expression argument
    in which exactly one yield expression can be located; ``node`` is
    returned unchanged otherwise.  The loop of the generator expression
    is modified in place and returned inside an
    InlinedGeneratorExpressionNode.
    """
    if len(pos_args) != 1:
        return node
    if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
        return node
    gen_expr_node = pos_args[0]
    loop_node = gen_expr_node.loop
    yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
    if yield_expression is None:
        return node

    # any() stops at the first true item, all() at the first false one
    if is_any:
        condition = yield_expression
    else:
        condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)

    result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
    # "if <condition>: result = <is_any>; break"
    test_node = Nodes.IfStatNode(
        yield_expression.pos,
        else_clause = None,
        if_clauses = [ Nodes.IfClauseNode(
            yield_expression.pos,
            condition = condition,
            body = Nodes.StatListNode(
                node.pos,
                stats = [
                    Nodes.SingleAssignmentNode(
                        node.pos,
                        lhs = result_ref,
                        rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
                                                 constant_result = is_any)),
                    Nodes.BreakStatNode(node.pos)
                    ])) ]
        )
    # propagate the break outwards through directly nested loops: each
    # inner loop gets "else: continue" and is followed by a "break"
    loop = loop_node
    while isinstance(loop.body, Nodes.LoopNode):
        next_loop = loop.body
        loop.body = Nodes.StatListNode(loop.body.pos, stats = [
            loop.body,
            Nodes.BreakStatNode(yield_expression.pos)
            ])
        next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
        loop = next_loop
    # result if the outermost loop ran to completion without a break
    loop_node.else_clause = Nodes.SingleAssignmentNode(
        node.pos,
        lhs = result_ref,
        rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
                                 constant_result = not is_any))

    # put the test in place of the yield statement
    Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)

    return ExprNodes.InlinedGeneratorExpressionNode(
        gen_expr_node.pos, loop = loop_node, result_node = result_ref,
        expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
| 1437 |
def _handle_simple_function_sorted(self, node, pos_args):
    """Transform sorted(genexpr) and sorted([listcomp]) into
    [listcomp].sort().  CPython just reads the iterable into a
    list and calls .sort() on it.  Expanding the iterable in a
    listcomp is still faster and the result can be sorted in
    place.

    Calls with any other argument, or with more than one argument,
    are returned unchanged.
    """
    if len(pos_args) != 1:
        return node
    if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
           and pos_args[0].type is Builtin.list_type:
        # already a list comprehension: use it as it is
        listcomp_node = pos_args[0]
        loop_node = listcomp_node.loop
    elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
        # turn the generator expression into a list comprehension by
        # replacing its yield statement with an append
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        append_node = ExprNodes.ComprehensionAppendNode(
            yield_expression.pos, expr = yield_expression)

        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)

        listcomp_node = ExprNodes.ComprehensionNode(
            gen_expr_node.pos, loop = loop_node,
            append = append_node, type = Builtin.list_type,
            expr_scope = gen_expr_node.expr_scope,
            has_local_scope = True)
        append_node.target = listcomp_node
    else:
        return node

    # result = [listcomp]
    result_node = UtilNodes.ResultRefNode(
        pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
    listcomp_assign_node = Nodes.SingleAssignmentNode(
        node.pos, lhs = result_node, rhs = listcomp_node, first = True)

    # result.sort()
    sort_method = ExprNodes.AttributeNode(
        node.pos, obj = result_node, attribute = EncodedString('sort'),
        # entry ? type ?
        needs_none_check = False)
    sort_node = Nodes.ExprStatNode(
        node.pos, expr = ExprNodes.SimpleCallNode(
            node.pos, function = sort_method, args = []))

    sort_node.analyse_declarations(self.current_env())

    return UtilNodes.TempResultFromStatNode(
        result_node,
        Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
| 1490 |
def _handle_simple_function_sum(self, node, pos_args):
    """Transform sum(genexpr) into an equivalent inlined aggregation loop.

    Accepts an optional start value as second argument (default: 0).
    A comprehension argument is only handled when its item expression
    is an integer literal.  All other calls are returned unchanged.
    """
    if len(pos_args) not in (1,2):
        return node
    if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
                                    ExprNodes.ComprehensionNode)):
        return node
    gen_expr_node = pos_args[0]
    loop_node = gen_expr_node.loop

    if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node
    else: # ComprehensionNode
        yield_stat_node = gen_expr_node.append
        yield_expression = yield_stat_node.expr
        try:
            if not yield_expression.is_literal or not yield_expression.type.is_int:
                return node
        except AttributeError:
            return node # in case we don't have a type yet
        # special case: old Py2 backwards compatible "sum([int_const for ...])"
        # can safely be unpacked into a genexpr

    if len(pos_args) == 1:
        start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
    else:
        start = pos_args[1]

    # "result = result + item" takes the place of the yield/append statement
    result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
    add_node = Nodes.SingleAssignmentNode(
        yield_expression.pos,
        lhs = result_ref,
        rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
        )

    Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)

    # "result = start", followed by the aggregation loop
    exec_code = Nodes.StatListNode(
        node.pos,
        stats = [
            Nodes.SingleAssignmentNode(
                start.pos,
                lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                rhs = start,
                first = True),
            loop_node
            ])

    return ExprNodes.InlinedGeneratorExpressionNode(
        gen_expr_node.pos, loop = exec_code, result_node = result_ref,
        expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
        has_local_scope = gen_expr_node.has_local_scope)
| 1546 |
| 1547 def _handle_simple_function_min(self, node, pos_args): |
| 1548 return self._optimise_min_max(node, pos_args, '<') |
| 1549 |
| 1550 def _handle_simple_function_max(self, node, pos_args): |
| 1551 return self._optimise_min_max(node, pos_args, '>') |
| 1552 |
| 1553 def _optimise_min_max(self, node, args, operator): |
| 1554 """Replace min(a,b,...) and max(a,b,...) by explicit comparison code. |
| 1555 """ |
| 1556 if len(args) <= 1: |
| 1557 if len(args) == 1 and args[0].is_sequence_constructor: |
| 1558 args = args[0].args |
| 1559 else: |
| 1560 # leave this to Python |
| 1561 return node |
| 1562 |
| 1563 cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:])) |
| 1564 |
| 1565 last_result = args[0] |
| 1566 for arg_node in cascaded_nodes: |
| 1567 result_ref = UtilNodes.ResultRefNode(last_result) |
| 1568 last_result = ExprNodes.CondExprNode( |
| 1569 arg_node.pos, |
| 1570 true_val = arg_node, |
| 1571 false_val = result_ref, |
| 1572 test = ExprNodes.PrimaryCmpNode( |
| 1573 arg_node.pos, |
| 1574 operand1 = arg_node, |
| 1575 operator = operator, |
| 1576 operand2 = result_ref, |
| 1577 ) |
| 1578 ) |
| 1579 last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result
) |
| 1580 |
| 1581 for ref_node in cascaded_nodes[::-1]: |
| 1582 last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result) |
| 1583 |
| 1584 return last_result |
| 1585 |
def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
    """Replace tuple() by an empty tuple literal and tuple(genexpr) by a
    list comprehension that is converted to a tuple.
    """
    if not pos_args:
        return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
    # This is a bit special - for iterables (including genexps),
    # Python actually overallocates and resizes a newly created
    # tuple incrementally while reading items, which we can't
    # easily do without explicit node support. Instead, we read
    # the items into a list and then copy them into a tuple of the
    # final size.  This takes up to twice as much memory, but will
    # have to do until we have real support for genexps.
    as_list = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
    if as_list is node:
        # nothing was transformed
        return node
    return ExprNodes.AsTupleNode(node.pos, arg=as_list)
| 1600 |
| 1601 def _handle_simple_function_frozenset(self, node, pos_args): |
| 1602 """Replace frozenset([...]) by frozenset((...)) as tuples are more effic
ient. |
| 1603 """ |
| 1604 if len(pos_args) != 1: |
| 1605 return node |
| 1606 if pos_args[0].is_sequence_constructor and not pos_args[0].args: |
| 1607 del pos_args[0] |
| 1608 elif isinstance(pos_args[0], ExprNodes.ListNode): |
| 1609 pos_args[0] = pos_args[0].as_tuple() |
| 1610 return node |
| 1611 |
def _handle_simple_function_list(self, node, pos_args):
    """Replace list() by an empty list literal and hand calls with
    arguments to _transform_list_set_genexpr().
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
    return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
| 1616 |
def _handle_simple_function_set(self, node, pos_args):
    """Replace set() by an empty set node and hand calls with arguments
    to _transform_list_set_genexpr().
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
    return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
| 1621 |
| 1622 def _transform_list_set_genexpr(self, node, pos_args, target_type): |
| 1623 """Replace set(genexpr) and list(genexpr) by a literal comprehension. |
| 1624 """ |
| 1625 if len(pos_args) > 1: |
| 1626 return node |
| 1627 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): |
| 1628 return node |
| 1629 gen_expr_node = pos_args[0] |
| 1630 loop_node = gen_expr_node.loop |
| 1631 |
| 1632 yield_expression, yield_stat_node = self._find_single_yield_expression(l
oop_node) |
| 1633 if yield_expression is None: |
| 1634 return node |
| 1635 |
| 1636 append_node = ExprNodes.ComprehensionAppendNode( |
| 1637 yield_expression.pos, |
| 1638 expr = yield_expression) |
| 1639 |
| 1640 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node
) |
| 1641 |
| 1642 comp = ExprNodes.ComprehensionNode( |
| 1643 node.pos, |
| 1644 has_local_scope = True, |
| 1645 expr_scope = gen_expr_node.expr_scope, |
| 1646 loop = loop_node, |
| 1647 append = append_node, |
| 1648 type = target_type) |
| 1649 append_node.target = comp |
| 1650 return comp |
| 1651 |
| 1652 def _handle_simple_function_dict(self, node, pos_args): |
| 1653 """Replace dict( (a,b) for ... ) by a literal { a:b for ... }. |
| 1654 """ |
| 1655 if len(pos_args) == 0: |
| 1656 return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_res
ult={}) |
| 1657 if len(pos_args) > 1: |
| 1658 return node |
| 1659 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): |
| 1660 return node |
| 1661 gen_expr_node = pos_args[0] |
| 1662 loop_node = gen_expr_node.loop |
| 1663 |
| 1664 yield_expression, yield_stat_node = self._find_single_yield_expression(l
oop_node) |
| 1665 if yield_expression is None: |
| 1666 return node |
| 1667 |
| 1668 if not isinstance(yield_expression, ExprNodes.TupleNode): |
| 1669 return node |
| 1670 if len(yield_expression.args) != 2: |
| 1671 return node |
| 1672 |
| 1673 append_node = ExprNodes.DictComprehensionAppendNode( |
| 1674 yield_expression.pos, |
| 1675 key_expr = yield_expression.args[0], |
| 1676 value_expr = yield_expression.args[1]) |
| 1677 |
| 1678 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node
) |
| 1679 |
| 1680 dictcomp = ExprNodes.ComprehensionNode( |
| 1681 node.pos, |
| 1682 has_local_scope = True, |
| 1683 expr_scope = gen_expr_node.expr_scope, |
| 1684 loop = loop_node, |
| 1685 append = append_node, |
| 1686 type = Builtin.dict_type) |
| 1687 append_node.target = dictcomp |
| 1688 return dictcomp |
| 1689 |
| 1690 # specific handlers for general call nodes |
| 1691 |
| 1692 def _handle_general_function_dict(self, node, pos_args, kwargs): |
| 1693 """Replace dict(a=b,c=d,...) by the underlying keyword dict |
| 1694 construction which is done anyway. |
| 1695 """ |
| 1696 if len(pos_args) > 0: |
| 1697 return node |
| 1698 if not isinstance(kwargs, ExprNodes.DictNode): |
| 1699 return node |
| 1700 return kwargs |
| 1701 |
| 1702 |
class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Replace simple calls to locally known Python functions by
    ``InlinedDefNodeCallNode`` when the
    'optimize.inline_defnode_calls' directive is enabled.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the right-hand side node of the single assignment to
        the name, or None if no such unique assignment is known.
        """
        state = name_node.cf_state
        if state is None or state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry:
            return None
        assignments = entry.cf_assignments
        if not assignments or len(assignments) != 1:
            # not just a single assignment in all closures
            return None
        return assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Try to turn a call on a plain name into an inlined call."""
        self.visitchildren(node)
        if not self.current_directives.get('optimize.inline_defnode_calls'):
            return node
        callee_name = node.function
        if not callee_name.is_name:
            return node
        callee = self.get_constant_value_node(callee_name)
        if not isinstance(callee, ExprNodes.PyCFunctionNode):
            return node
        candidate = ExprNodes.InlinedDefNodeCallNode(
            node.pos,
            function_name=callee_name,
            function=callee,
            args=node.args)
        if not candidate.can_be_inlined():
            return node
        return self.replace(node, candidate)
| 1734 |
| 1735 |
class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
    """Optimize some common method calls and instantiation patterns
    for builtin types *after* the type analysis phase.

    Running after type analysis, this transform can only perform
    function replacements that do not alter the function return type
    in a way that was not anticipated by the type analysis.
    """
| 1744 ### cleanup to avoid redundant coercions to/from Python types |
| 1745 |
def _visit_PyTypeTestNode(self, node):
    # disabled - appears to break assignments in some cases, and
    # also drops a None check, which might still be required
    """Flatten a type test that became redundant after tree changes.

    The test is dropped only when visiting the children replaced the
    argument and the new argument already has the tested type.
    """
    previous_arg = node.arg
    self.visitchildren(node)
    replaced_arg = node.arg
    if previous_arg is replaced_arg:
        # nothing changed underneath, keep the test
        return node
    if replaced_arg.type != node.type:
        return node
    return replaced_arg
| 1756 |
def _visit_TypecastNode(self, node):
    # disabled - the user may have had a reason to put a type
    # cast, even if it looks redundant to Cython
    """
    Drop a type cast whose operand already has the target type.
    """
    self.visitchildren(node)
    if not (node.type == node.operand.type):
        return node
    return node.operand
| 1767 |
def visit_ExprStatNode(self, node):
    """
    Strip a coercion to a Python object from an expression statement;
    the statement then evaluates the uncoerced expression directly.
    """
    self.visitchildren(node)
    expression = node.expr
    if isinstance(expression, ExprNodes.CoerceToPyTypeNode):
        node.expr = expression.arg
    return node
| 1776 |
def visit_CoerceToBooleanNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    If the operand (looking through a type test) is a C value that was
    coerced to a generic Python object or a bool, coerce the C value to
    a boolean directly.
    """
    self.visitchildren(node)
    operand = node.arg
    if isinstance(operand, ExprNodes.PyTypeTestNode):
        operand = operand.arg
    if not isinstance(operand, ExprNodes.CoerceToPyTypeNode):
        return node
    if operand.type in (PyrexTypes.py_object_type, Builtin.bool_type):
        return operand.arg.coerce_to_boolean(self.current_env())
    return node
| 1788 |
def visit_CoerceFromPyTypeNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    Also optimises away calls to Python's builtin int() and float()
    when the result is coerced back into a C type anyway, and hands
    integer indexing into non-buffer objects to
    ``_optimise_int_indexing()``.
    """
    self.visitchildren(node)
    arg = node.arg
    if not arg.type.is_pyobject:
        # no Python conversion left at all, just do a C coercion instead
        if node.type == arg.type:
            return arg
        return arg.coerce_to(node.type, self.current_env())

    if isinstance(arg, ExprNodes.PyTypeTestNode):
        arg = arg.arg

    if arg.is_literal:
        # int/bool literal -> C integer, float literal -> C float
        literal_fits = (
            (node.type.is_int and isinstance(
                arg, (ExprNodes.IntNode, ExprNodes.BoolNode)))
            or (node.type.is_float and isinstance(arg, ExprNodes.FloatNode)))
        if literal_fits:
            return arg.coerce_to(node.type, self.current_env())
        return node

    if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
        if (arg.type is PyrexTypes.py_object_type
                and node.type.assignable_from(arg.arg.type)):
            # completely redundant C->Py->C coercion
            return arg.arg.coerce_to(node.type, self.current_env())
        return node

    if isinstance(arg, ExprNodes.SimpleCallNode):
        if node.type.is_int or node.type.is_float:
            return self._optimise_numeric_cast_call(node, arg)
        return node

    if isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
        index = arg.index
        if isinstance(index, ExprNodes.CoerceToPyTypeNode):
            index = index.arg
        if index.type.is_int:
            return self._optimise_int_indexing(node, arg, index)
    return node
| 1826 |
# C signature used for the "__Pyx_PyBytes_GetItemInt" call below:
# char f(bytes, Py_ssize_t index, int check_bounds), error value (char)-1
PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_char_type,
    [
        PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
    ],
    exception_value="((char)-1)",
    exception_check=True)

def _optimise_int_indexing(self, coerce_node, arg, index_node):
    """Replace ``<char>bytes_obj[int_index]`` by a direct C helper call.

    Only applies when the indexed base is typed as bytes and the
    coercion target is a C char or unsigned char; otherwise the
    original coercion node is returned.
    """
    env = self.current_env()
    if env.directives['boundscheck']:
        check_bounds = 1
    else:
        check_bounds = 0

    if arg.base.type is not Builtin.bytes_type:
        return coerce_node
    if coerce_node.type not in (PyrexTypes.c_char_type,
                                PyrexTypes.c_uchar_type):
        return coerce_node

    # bytes[index] -> char
    check_bounds_node = ExprNodes.IntNode(
        coerce_node.pos,
        value=str(check_bounds),
        constant_result=check_bounds)
    safe_base = arg.base.as_none_safe_node(
        "'NoneType' object is not subscriptable")
    c_index = index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env)
    result = ExprNodes.PythonCapiCallNode(
        coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
        self.PyBytes_GetItemInt_func_type,
        args=[safe_base, c_index, check_bounds_node],
        is_temp=True,
        utility_code=UtilityCode.load_cached('bytes_index', 'StringTools.c'))
    if coerce_node.type is not PyrexTypes.c_char_type:
        # the helper returns a plain char, convert for 'unsigned char'
        result = result.coerce_to(coerce_node.type, env)
    return result
| 1860 |
def _optimise_numeric_cast_call(self, node, arg):
    """Short-cut ``<ctype>int(x)`` / ``<ctype>float(x)`` for C-typed x.

    ``node`` is the coercion from a Python object, ``arg`` the call
    node it wraps.  Returns the bare argument or a C type cast when
    the builtin call is redundant, otherwise ``node`` unchanged.
    """
    function = arg.function
    if not isinstance(function, ExprNodes.NameNode):
        return node
    if not function.type.is_builtin_type:
        return node
    if not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
        return node

    call_args = arg.arg_tuple.args
    if len(call_args) != 1:
        return node

    operand = call_args[0]
    if isinstance(operand, ExprNodes.CoerceToPyTypeNode):
        operand = operand.arg
    elif operand.type.is_pyobject:
        # play safe: Python conversion might work on all sorts of things
        return node

    builtin_name = function.name
    if builtin_name == 'int':
        applicable = operand.type.is_int or node.type.is_int
    elif builtin_name == 'float':
        applicable = operand.type.is_float or node.type.is_float
    else:
        applicable = False

    if applicable:
        if operand.type == node.type:
            return operand
        elif node.type.assignable_from(operand.type) or operand.type.is_float:
            return ExprNodes.TypecastNode(
                node.pos, operand=operand, type=node.type)
    return node
| 1891 |
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call with the wrong number of arguments at ``node.pos``.

    ``expected`` may be None, a number, or a descriptive string such as
    '0 or 1'.  It is included in the message unless it is None, and it
    selects the placeholder shown between the parentheses.
    """
    placeholder = ''
    if expected:  # neither None nor 0
        if isinstance(expected, basestring) or expected > 1:
            placeholder = '...'
        elif expected == 1:
            placeholder = 'x'

    expectation = ''
    if expected is not None:
        expectation = 'expected %s, ' % expected

    message = "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, placeholder, expectation, len(args))
    error(node.pos, message)
| 1907 |
| 1908 ### generic fallbacks |
| 1909 |
| 1910 def _handle_function(self, node, function_name, function, arg_list, kwargs): |
| 1911 return node |
| 1912 |
| 1913 def _handle_method(self, node, type_name, attr_name, function, |
| 1914 arg_list, is_unbound_method, kwargs): |
| 1915 """ |
| 1916 Try to inject C-API calls for unbound method calls to builtin types. |
| 1917 While the method declarations in Builtin.py already handle this, we |
| 1918 can additionally resolve bound and unbound methods here that were |
| 1919 assigned to variables ahead of time. |
| 1920 """ |
| 1921 if kwargs: |
| 1922 return node |
| 1923 if not function or not function.is_attribute or not function.obj.is_name
: |
| 1924 # cannot track unbound method calls over more than one indirection a
s |
| 1925 # the names might have been reassigned in the meantime |
| 1926 return node |
| 1927 type_entry = self.current_env().lookup(type_name) |
| 1928 if not type_entry: |
| 1929 return node |
| 1930 method = ExprNodes.AttributeNode( |
| 1931 node.function.pos, |
| 1932 obj=ExprNodes.NameNode( |
| 1933 function.pos, |
| 1934 name=type_name, |
| 1935 entry=type_entry, |
| 1936 type=type_entry.type), |
| 1937 attribute=attr_name, |
| 1938 is_called=True).analyse_as_unbound_cmethod_node(self.current_env()) |
| 1939 if method is None: |
| 1940 return node |
| 1941 args = node.args |
| 1942 if args is None and node.arg_tuple: |
| 1943 args = node.arg_tuple.args |
| 1944 call_node = ExprNodes.SimpleCallNode( |
| 1945 node.pos, |
| 1946 function=method, |
| 1947 args=args) |
| 1948 if not is_unbound_method: |
| 1949 call_node.self = function.obj |
| 1950 call_node.analyse_c_function_call(self.current_env()) |
| 1951 call_node.analysed = True |
| 1952 return call_node.coerce_to(node.type, self.current_env()) |
| 1953 |
| 1954 ### builtin types |
| 1955 |
# C signature of PyDict_Copy(): dict f(dict)
PyDict_Copy_func_type = PyrexTypes.CFuncType(
    Builtin.dict_type, [
        PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
        ])

def _handle_simple_function_dict(self, node, function, pos_args):
    """Replace dict(some_dict) by PyDict_Copy(some_dict).

    Only applies to a single argument that is typed as dict; any other
    call is returned unchanged.
    """
    if len(pos_args) != 1:
        return node
    arg = pos_args[0]
    if arg.type is Builtin.dict_type:
        # Use the same wording as CPython's dict(None) error and as the
        # tuple() handler in this class.
        arg = arg.as_none_safe_node("'NoneType' object is not iterable")
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
            args = [arg],
            is_temp = node.is_temp
            )
    return node
| 1975 |
# C signature of PyList_AsTuple(): tuple f(list)
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
    Builtin.tuple_type,
    [PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)])

def _handle_simple_function_tuple(self, node, function, pos_args):
    """Replace tuple([...]) by a call to PyList_AsTuple.

    A single argument that is a tuple and cannot be None is returned
    as is.  Arguments that are not typed as list leave the call
    unchanged.
    """
    if len(pos_args) != 1:
        return node
    sequence = pos_args[0]
    sequence_type = sequence.type
    if sequence_type is Builtin.tuple_type and not sequence.may_be_none():
        return sequence
    if sequence_type is not Builtin.list_type:
        return node

    # guard the C call against a None list (updates the list in place)
    pos_args[0] = sequence.as_none_safe_node(
        "'NoneType' object is not iterable")
    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
        args=pos_args,
        is_temp=node.is_temp)
| 1999 |
# C signature of PySet_New(): set f(object iterable)
PySet_New_func_type = PyrexTypes.CFuncType(
    Builtin.set_type,
    [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])

def _handle_simple_function_set(self, node, function, pos_args):
    """Optimise ``set(x)`` with exactly one argument.

    A sequence constructor argument becomes a set literal; any other
    argument is passed to PySet_New().
    """
    if len(pos_args) != 1:
        return node

    if not pos_args[0].is_sequence_constructor:
        # PySet_New(it) is better than a generic Python call to set(it)
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PySet_New",
            self.PySet_New_func_type,
            args=pos_args,
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached('pyset_compat', 'Builtins.c'),
            py_name="set")

    # We can optimise set([x,y,z]) safely into a set literal,
    # but only if we create all items before adding them -
    # adding an item may raise an exception if it is not
    # hashable, but creating the later items may have
    # side-effects.
    items = []
    let_refs = []
    for item in pos_args[0].args:
        if not item.is_simple():
            item = UtilNodes.LetRefNode(item)
            let_refs.append(item)
        items.append(item)

    set_node = ExprNodes.SetNode(node.pos, is_temp=1, args=items)
    # wrap from the last temp outwards so the first one ends up outermost
    for let_ref in reversed(let_refs):
        set_node = UtilNodes.EvalWithTempExprNode(let_ref, set_node)
    return set_node
| 2034 |
# C signature used for "__Pyx_PyFrozenSet_New": frozenset f(object iterable)
PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
    Builtin.frozenset_type,
    [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])

def _handle_simple_function_frozenset(self, node, function, pos_args):
    """Optimise ``frozenset()`` and ``frozenset(x)``.

    Without arguments, a NULL argument is passed to the C helper.  An
    argument that is a frozenset and cannot be None is returned as is.
    Calls with more than one argument are left unchanged.
    """
    arg_count = len(pos_args)
    if arg_count > 1:
        return node
    if arg_count == 0:
        pos_args = [ExprNodes.NullNode(node.pos)]
    else:
        candidate = pos_args[0]
        if (candidate.type is Builtin.frozenset_type
                and not candidate.may_be_none()):
            return candidate

    # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyFrozenSet_New",
        self.PyFrozenSet_New_func_type,
        args=pos_args,
        is_temp=node.is_temp,
        utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
        py_name="frozenset")
| 2055 |
# C signature used for "__Pyx_PyObject_AsDouble":
# double f(object), error value (double)-1 with exception check
PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_double_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "((double)-1)",
    exception_check = True)

def _handle_simple_function_float(self, node, function, pos_args):
    """Transform float() into either a C type cast or a faster C
    function call.

    ``float()`` without arguments becomes the literal 0.0.  More than
    one argument is reported as an error and the call is left as is.
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) == 0:
        # Pass the source position, not the call node itself, as the
        # position of the new literal (as all other nodes built here do).
        return ExprNodes.FloatNode(
            node.pos, value="0.0", constant_result=0.0
            ).coerce_to(Builtin.float_type, self.current_env())
    elif len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
        return node
    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        # look at the C value underneath the Python coercion
        func_arg = func_arg.arg
    if func_arg.type is PyrexTypes.c_double_type:
        return func_arg
    elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = load_c_utility('pyobject_as_double'),
        py_name = "float")
| 2091 |
# C signature of PyNumber_Int(): object f(object)
PyNumber_Int_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
        ])

def _handle_simple_function_int(self, node, function, pos_args):
    """Transform int() into a faster C function call.

    ``int()`` without arguments becomes the Python integer literal 0.
    Calls with a base argument, or with an argument that was coerced
    from a C value, are left unchanged.
    """
    if len(pos_args) == 0:
        # Pass the source position, not the call node itself, as the
        # position of the new literal (as all other nodes built here do).
        return ExprNodes.IntNode(node.pos, value="0", constant_result=0,
                                 type=PyrexTypes.py_object_type)
    elif len(pos_args) != 1:
        return node  # int(x, base)
    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        return node  # handled in visit_CoerceFromPyTypeNode()
    if func_arg.type.is_pyobject and node.type.is_pyobject:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyNumber_Int", self.PyNumber_Int_func_type,
            args=pos_args, is_temp=True)
    return node
| 2113 |
def _handle_simple_function_bool(self, node, function, pos_args):
    """Transform bool(x) into a type coercion to a boolean.

    ``bool()`` without arguments becomes the literal False.  More than
    one argument is reported as an error and the call is left as is.
    """
    arg_count = len(pos_args)
    if arg_count == 0:
        false_node = ExprNodes.BoolNode(
            node.pos, value=False, constant_result=False)
        return false_node.coerce_to(Builtin.bool_type, self.current_env())
    if arg_count != 1:
        self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
        return node

    # => !!<bint>(x) to make sure it's exactly 0 or 1
    truth_value = pos_args[0].coerce_to_boolean(self.current_env())
    negated = ExprNodes.NotNode(node.pos, operand=truth_value)
    normalised = ExprNodes.NotNode(node.pos, operand=negated)
    # coerce back to Python object as that's the result we are expecting
    return normalised.coerce_to_pyobject(self.current_env())
| 2131 |
| 2132 ### builtin functions |
| 2133 |
# C signature used for strlen(): size_t f(char*)
Pyx_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type,
    [PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)])

# C signature used for "__Pyx_Py_UNICODE_strlen": size_t f(Py_UNICODE*)
Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type,
    [PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_py_unicode_ptr_type, None)])

# C signature shared by the object size functions: Py_ssize_t f(object)
PyObject_Size_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value="-1")

# lookup: builtin type -> name of the C function/macro returning its size
_map_to_capi_len_function = {
    Builtin.unicode_type   : "__Pyx_PyUnicode_GET_LENGTH",
    Builtin.bytes_type     : "PyBytes_GET_SIZE",
    Builtin.list_type      : "PyList_GET_SIZE",
    Builtin.tuple_type     : "PyTuple_GET_SIZE",
    Builtin.dict_type      : "PyDict_Size",
    Builtin.set_type       : "PySet_Size",
    Builtin.frozenset_type : "PySet_Size",
    }.get

# qualified names of types for which 'Py_SIZE' is used instead
_ext_types_with_pysize = set(["cpython.array.array"])

def _handle_simple_function_len(self, node, function, pos_args):
    """Replace len(char*) by the equivalent call to strlen(),
    len(Py_UNICODE*) by the equivalent Py_UNICODE_strlen() and
    len(known_builtin_type) by an equivalent C-API call.

    ``len()`` of a single unicode character becomes the constant 1.
    A wrong argument count is reported as an error.
    """
    if len(pos_args) != 1:
        self._error_wrong_arg_count('len', node, pos_args, 1)
        return node

    arg = pos_args[0]
    if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
        arg = arg.arg
    arg_type = arg.type

    if arg_type.is_string:
        size_call = ExprNodes.PythonCapiCallNode(
            node.pos, "strlen", self.Pyx_strlen_func_type,
            args=[arg],
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached(
                "IncludeStringH", "StringTools.c"))
    elif arg_type.is_pyunicode_ptr:
        size_call = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py_UNICODE_strlen",
            self.Pyx_Py_UNICODE_strlen_func_type,
            args=[arg],
            is_temp=node.is_temp)
    elif arg_type.is_pyobject:
        cfunc_name = self._map_to_capi_len_function(arg_type)
        if cfunc_name is None:
            has_pysize = (
                (arg_type.is_extension_type or arg_type.is_builtin_type)
                and arg_type.entry.qualified_name in self._ext_types_with_pysize)
            if not has_pysize:
                return node
            cfunc_name = 'Py_SIZE'
        arg = arg.as_none_safe_node(
            "object of type 'NoneType' has no len()")
        size_call = ExprNodes.PythonCapiCallNode(
            node.pos, cfunc_name, self.PyObject_Size_func_type,
            args=[arg],
            is_temp=node.is_temp)
    elif arg_type.is_unicode_char:
        return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                 type=node.type)
    else:
        return node

    if node.type not in (PyrexTypes.c_size_t_type,
                         PyrexTypes.c_py_ssize_t_type):
        # convert the C size to whatever type the original call has
        size_call = size_call.coerce_to(node.type, self.current_env())
    return size_call
| 2207 |
# C signature used for the Py_TYPE() macro: type f(object)
Pyx_Type_func_type = PyrexTypes.CFuncType(
    Builtin.type_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)])

def _handle_simple_function_type(self, node, function, pos_args):
    """Replace type(o) by a macro call to Py_TYPE(o).

    The macro result is cast to a generic Python object.  Calls with
    any other number of arguments are left unchanged.
    """
    if len(pos_args) != 1:
        return node
    type_macro_call = ExprNodes.PythonCapiCallNode(
        node.pos, "Py_TYPE", self.Pyx_Type_func_type,
        args=pos_args,
        is_temp=False)
    return ExprNodes.CastNode(type_macro_call, PyrexTypes.py_object_type)
| 2223 |
# C signature used for the type check functions: bint f(object)
Py_type_check_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type, [
        PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
        ])

def _handle_simple_function_isinstance(self, node, function, pos_args):
    """Replace isinstance() checks against builtin types by the
    corresponding C-API call.

    The second argument may be a single type or a tuple of types; the
    individual checks are joined with 'or'.  The call is left unchanged
    if any of the tested types is not known to be a type, or if the
    tuple of types is empty.
    """
    if len(pos_args) != 2:
        return node
    arg, types = pos_args
    temp = None
    if isinstance(types, ExprNodes.TupleNode):
        types = types.args
        if not types:
            # isinstance(x, ()): there are no checks to join, and
            # reduce() below would fail on an empty list, so keep the
            # generic call
            return node
        if arg.is_attribute or not arg.is_simple():
            # the tested object is used once per type: evaluate it once
            arg = temp = UtilNodes.ResultRefNode(arg)
    elif types.type is Builtin.type_type:
        types = [types]
    else:
        return node

    tests = []
    test_nodes = []
    env = self.current_env()
    for test_type_node in types:
        builtin_type = None
        if test_type_node.is_name:
            if test_type_node.entry:
                entry = env.lookup(test_type_node.entry.name)
                if entry and entry.type and entry.type.is_builtin_type:
                    builtin_type = entry.type
        if builtin_type is Builtin.type_type:
            # all types have type "type", but there's only one 'type'
            if entry.name != 'type' or not (
                    entry.scope and entry.scope.is_builtin_scope):
                builtin_type = None
        if builtin_type is not None:
            type_check_function = entry.type.type_check_function(exact=False)
            if type_check_function in tests:
                # the same check was already emitted for an earlier type
                continue
            tests.append(type_check_function)
            type_check_args = [arg]
        elif test_type_node.type is Builtin.type_type:
            type_check_function = '__Pyx_TypeCheck'
            type_check_args = [arg, test_type_node]
        else:
            return node
        test_nodes.append(
            ExprNodes.PythonCapiCallNode(
                test_type_node.pos, type_check_function, self.Py_type_check_func_type,
                args = type_check_args,
                is_temp = True,
                ))

    def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
        or_node = make_binop_node(node.pos, 'or', a, b)
        or_node.type = PyrexTypes.c_bint_type
        or_node.is_temp = True
        return or_node

    test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
    if temp is not None:
        test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
    return test_node
| 2289 |
| 2290 def _handle_simple_function_ord(self, node, function, pos_args): |
| 2291 """Unpack ord(Py_UNICODE) and ord('X'). |
| 2292 """ |
| 2293 if len(pos_args) != 1: |
| 2294 return node |
| 2295 arg = pos_args[0] |
| 2296 if isinstance(arg, ExprNodes.CoerceToPyTypeNode): |
| 2297 if arg.arg.type.is_unicode_char: |
| 2298 return ExprNodes.TypecastNode( |
| 2299 arg.pos, operand=arg.arg, type=PyrexTypes.c_int_type |
| 2300 ).coerce_to(node.type, self.current_env()) |
| 2301 elif isinstance(arg, ExprNodes.UnicodeNode): |
| 2302 if len(arg.value) == 1: |
| 2303 return ExprNodes.IntNode( |
| 2304 arg.pos, type=PyrexTypes.c_int_type, |
| 2305 value=str(ord(arg.value)), |
| 2306 constant_result=ord(arg.value) |
| 2307 ).coerce_to(node.type, self.current_env()) |
| 2308 elif isinstance(arg, ExprNodes.StringNode): |
| 2309 if arg.unicode_value and len(arg.unicode_value) == 1 \ |
| 2310 and ord(arg.unicode_value) <= 255: # Py2/3 portability |
| 2311 return ExprNodes.IntNode( |
| 2312 arg.pos, type=PyrexTypes.c_int_type, |
| 2313 value=str(ord(arg.unicode_value)), |
| 2314 constant_result=ord(arg.unicode_value) |
| 2315 ).coerce_to(node.type, self.current_env()) |
| 2316 return node |
| 2317 |
| 2318 ### special methods |
| 2319 |
# C signature used for "__Pyx_tp_new": object f(object type, tuple args)
Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
        ])

# C signature used for "__Pyx_tp_new_kwargs":
# object f(object type, tuple args, dict kwargs)
Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
        PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
        ])

def _handle_any_slot__new__(self, node, function, args,
                            is_unbound_method, kwargs=None):
    """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()

    Only unbound calls with at least one argument are handled, and only
    when both the object and the first argument are names typed as
    'type' that refer to the same type.  All other calls are returned
    unchanged.
    """
    obj = function.obj
    if not is_unbound_method or len(args) < 1:
        return node
    type_arg = args[0]
    if not obj.is_name or not type_arg.is_name:
        # play safe
        return node
    if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
        # not a known type, play safe
        return node
    if not type_arg.type_entry or not obj.type_entry:
        if obj.name != type_arg.name:
            return node
        # otherwise, we know it's a type and we know it's the same
        # type for both - that should do
    elif type_arg.type_entry != obj.type_entry:
        # different types - may or may not lead to an error at runtime
        return node

    # the remaining arguments are passed on as the 'args' tuple
    args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
    args_tuple = args_tuple.analyse_types(
        self.current_env(), skip_children=True)

    if type_arg.type_entry:
        ext_type = type_arg.type_entry.type
        if (ext_type.is_extension_type and ext_type.typeobj_cname and
                ext_type.scope.global_scope() == self.current_env().global_scope()):
            # known type in current module
            tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
            slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
            if slot_func_cname:
                # call the slot function directly, passing the type as a
                # PyTypeObject pointer
                cython_scope = self.context.cython_scope
                PyTypeObjectPtr = PyrexTypes.CPtrType(
                    cython_scope.lookup('PyTypeObject').type)
                pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                    PyrexTypes.py_object_type, [
                        PyrexTypes.CFuncTypeArg("type",   PyTypeObjectPtr, None),
                        PyrexTypes.CFuncTypeArg("args",   PyrexTypes.py_object_type, None),
                        PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                        ])

                type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                if not kwargs:
                    # no keyword arguments: pass NULL in their place
                    kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                return ExprNodes.PythonCapiCallNode(
                    node.pos, slot_func_cname,
                    pyx_tp_new_kwargs_func_type,
                    args=[type_arg, args_tuple, kwargs],
                    is_temp=True)
    else:
        # arbitrary variable, needs a None check for safety
        type_arg = type_arg.as_none_safe_node(
            "object.__new__(X): X is not a type object (NoneType)")

    # generic fallback: go through the tp_new utility functions
    utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
    if kwargs:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
            args=[type_arg, args_tuple, kwargs],
            utility_code=utility_code,
            is_temp=node.is_temp
            )
    else:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args=[type_arg, args_tuple],
            utility_code=utility_code,
            is_temp=node.is_temp
            )
| 2406 |
| 2407 ### methods of builtin types |
| 2408 |
# C signature used for "__Pyx_PyObject_Append":
# returncode f(object list, object item), error value -1
PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
    ],
    exception_value="-1")

def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
    """Optimistic optimisation as X.append() is almost always
    referring to a list.

    Only applies when the call has exactly two arguments (object and
    item) and its result is not used.
    """
    if len(args) != 2:
        return node
    if node.result_is_used:
        return node

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
        args=args,
        may_return_none=False,
        is_temp=node.is_temp,
        result_is_used=False,
        utility_code=load_c_utility('append'))
| 2431 |
# C signature used for "__Pyx_PyByteArray_Append":
# returncode f(object bytearray, int value), error value -1
PyByteArray_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [
        PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
    ],
    exception_value="-1")

# C signature used for "__Pyx_PyByteArray_AppendObject":
# returncode f(object bytearray, object value), error value -1
PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [
        PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
    ],
    exception_value="-1")

def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
    """Replace ``bytearray.append(value)`` by a C helper call.

    C integers, integer literals and string literals that can be
    coerced to a char literal use the int based helper; Python object
    values use the object based helper.  Other calls are unchanged.
    """
    if len(args) != 2:
        return node
    c_name = "__Pyx_PyByteArray_Append"
    c_type = self.PyByteArray_Append_func_type
    utility_name = "ByteArrayAppend"

    value = unwrap_coerced_node(args[1])
    if value.type.is_int or isinstance(value, ExprNodes.IntNode):
        value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
    elif value.is_string_literal:
        if not value.can_coerce_to_char_literal():
            return node
        value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
    elif value.type.is_pyobject:
        c_name = "__Pyx_PyByteArray_AppendObject"
        c_type = self.PyByteArray_AppendObject_func_type
        utility_name = "ByteArrayAppendObject"
    else:
        return node
    utility_code = UtilityCode.load_cached(utility_name, "StringTools.c")

    append_call = ExprNodes.PythonCapiCallNode(
        node.pos, c_name, c_type,
        args=[args[0], value],
        may_return_none=False,
        is_temp=node.is_temp,
        utility_code=utility_code)
    if node.result_is_used:
        append_call = append_call.coerce_to(node.type, self.current_env())
    return append_call
| 2478 |
# C signature used for the "__Pyx_Py<Type>_Pop" helper calls: takes only the
# container and returns a Python object.
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        ])

# C signature used for the "__Pyx_Py<Type>_PopIndex" helper calls: takes the
# container plus a C long index and returns a Python object.
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
        ])
| 2489 |
def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
    """Optimise list.pop([n]) through the generic pop() handler, with
    the receiver marked as a known list.
    """
    optimised = self._handle_simple_method_object_pop(
        node, function, args, is_unbound_method, is_list=True)
    return optimised
| 2493 |
def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
    """Optimistically replace X.pop([n]) by a C helper call, since X is
    almost always a list.

    With 'is_list' set, the list-specific helpers are used and the
    receiver is wrapped by as_none_safe_node().  Calls that do not match
    a supported shape are returned unchanged.
    """
    if not args:
        return node

    call_args = args[:]  # work on a copy, the caller's list stays untouched
    type_name = 'List' if is_list else 'Object'
    if is_list:
        call_args[0] = call_args[0].as_none_safe_node(
            "'NoneType' object has no attribute '%s'",
            error="PyExc_AttributeError",
            format_args=['pop'])

    arg_count = len(call_args)
    if arg_count == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py%s_Pop" % type_name,
            self.PyObject_Pop_func_type,
            args=call_args,
            may_return_none=True,
            is_temp=node.is_temp,
            utility_code=load_c_utility('pop'),
        )
    if arg_count != 2:
        return node

    index = unwrap_coerced_node(call_args[1])
    if is_list or isinstance(index, ExprNodes.IntNode):
        index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
    if not index.type.is_int:
        return node

    # only use the C helper if the index type fits into Py_ssize_t
    widest = PyrexTypes.widest_numeric_type(
        index.type, PyrexTypes.c_py_ssize_t_type)
    if not (widest == PyrexTypes.c_py_ssize_t_type):
        return node

    call_args[1] = index
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_Py%s_PopIndex" % type_name,
        self.PyObject_PopIndex_func_type,
        args=call_args,
        may_return_none=True,
        is_temp=node.is_temp,
        utility_code=load_c_utility("pop_index"),
    )
| 2537 |
# C signature of a function that takes a single Python object and returns a
# C return code, where -1 signals an error (used for PyList_Sort() below).
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "-1")
| 2543 |
def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
    """Call PyList_Sort() instead of the 0-argument l.sort().

    Calls that pass any argument besides the list itself are returned
    unchanged.  The substituted C call is coerced back to the type of
    the original call node.
    """
    if len(args) != 1:
        return node
    sort_call = self._substitute_method_call(
        node, function, "PyList_Sort", self.single_param_func_type,
        'sort', is_unbound_method, args)
    # coerce_to() expects the environment object, so current_env must be
    # called here rather than passed along as a bound method
    return sort_call.coerce_to(node.type, self.current_env())
| 2552 |
# C signature used for the "__Pyx_PyDict_GetItemDefault" helper call: takes
# the dict, the key and the default value, and returns a Python object.
Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])
| 2559 |
def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
    """Replace dict.get() by a call to __Pyx_PyDict_GetItemDefault().

    When no default value is passed, a None node is appended to 'args'
    (in place).  Any other argument count is passed on to
    _error_wrong_arg_count() and the call is returned unchanged.
    """
    arg_count = len(args)
    if arg_count not in (2, 3):
        self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
        return node
    if arg_count == 2:
        # no explicit default => default to None
        args.append(ExprNodes.NoneNode(node.pos))

    return self._substitute_method_call(
        node, function,
        "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
        'get', is_unbound_method, args,
        may_return_none=True,
        utility_code=load_c_utility("dict_getitem_default"))
| 2575 |
# C signature used for the "__Pyx_PyDict_SetDefault" helper call: takes the
# dict, the key, the default value and a C int flag describing whether the
# key type is known to be safe, and returns a Python object.
Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
        ])
| 2583 |
def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
    """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().

    A missing default argument is filled in with None (appended to
    'args' in place).  An extra 'is_safe_type' integer argument is
    appended for the C helper:

      1  - the key has one of the builtin types listed below
      0  - the key type is known and is not one of them
      -1 - the key is a generic Python object, so its type is unknown

    Any other argument count is passed on to _error_wrong_arg_count()
    and the call is returned unchanged.
    """
    if len(args) == 2:
        args.append(ExprNodes.NoneNode(node.pos))
    elif len(args) != 3:
        self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
        return node
    key_type = args[1].type
    if key_type.is_builtin_type:
        # exact name match; a substring test against one long string would
        # also accept any type name that happens to be contained in it
        is_safe_type = int(key_type.name in (
            'str', 'bytes', 'unicode', 'float', 'int', 'long', 'bool'))
    elif key_type is PyrexTypes.py_object_type:
        is_safe_type = -1  # don't know
    else:
        is_safe_type = 0   # definitely not
    args.append(ExprNodes.IntNode(
        node.pos, value=str(is_safe_type), constant_result=is_safe_type))

    return self._substitute_method_call(
        node, function,
        "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
        'setdefault', is_unbound_method, args,
        may_return_none=True,
        utility_code=load_c_utility('dict_setdefault'))
| 2609 |
| 2610 |
| 2611 ### unicode type methods |
| 2612 |
# C signature of the character predicate functions substituted by
# _inject_unicode_predicate(): takes a single c_py_ucs4_type character and
# returns a C boolean.
PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type, [
        PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
        ])
| 2617 |
def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
    """Replace a unicode predicate method call (isalpha(), isdigit(), ...)
    on a single unicode character by the matching Py_UNICODE_IS*() call.

    Only applies to bound method calls whose receiver is a C unicode
    character that was coerced to a Python object; everything else is
    returned unchanged.  The C function name is derived from the method
    name, except for istitle(), which uses a utility function.
    """
    if is_unbound_method or len(args) != 1:
        return node
    ustring = args[0]
    if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
           not ustring.arg.type.is_unicode_char:
        return node
    uchar = ustring.arg
    method_name = function.attribute
    if method_name == 'istitle':
        # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
        utility_code = UtilityCode.load_cached(
            "py_unicode_istitle", "StringTools.c")
        function_name = '__Pyx_Py_UNICODE_ISTITLE'
    else:
        utility_code = None
        function_name = 'Py_UNICODE_%s' % method_name.upper()
    func_call = self._substitute_method_call(
        node, function,
        function_name, self.PyUnicode_uchar_predicate_func_type,
        method_name, is_unbound_method, [uchar],
        utility_code = utility_code)
    if node.type.is_pyobject:
        # the coercion needs the environment object itself, so current_env
        # must be called rather than passed as a bound method
        func_call = func_call.coerce_to_pyobject(self.current_env())
    return func_call
| 2643 |
# All of these unicode predicate methods share one handler, which derives
# the C function name from the name of the called method.
_handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
_handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
_handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
_handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
_handle_simple_method_unicode_islower   = _inject_unicode_predicate
_handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
_handle_simple_method_unicode_isspace   = _inject_unicode_predicate
_handle_simple_method_unicode_istitle   = _inject_unicode_predicate
_handle_simple_method_unicode_isupper   = _inject_unicode_predicate
| 2653 |
# C signature of the character conversion functions substituted by
# _inject_unicode_character_conversion(): maps one c_py_ucs4_type character
# to another.
PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ucs4_type, [
        PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
        ])
| 2658 |
def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
    """Replace lower()/upper()/title() on a single unicode character by
    the matching Py_UNICODE_TO*() call.

    Only applies to bound method calls whose receiver is a C unicode
    character that was coerced to a Python object; everything else is
    returned unchanged.
    """
    if is_unbound_method or len(args) != 1:
        return node
    ustring = args[0]
    if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
           not ustring.arg.type.is_unicode_char:
        return node
    uchar = ustring.arg
    method_name = function.attribute
    function_name = 'Py_UNICODE_TO%s' % method_name.upper()
    func_call = self._substitute_method_call(
        node, function,
        function_name, self.PyUnicode_uchar_conversion_func_type,
        method_name, is_unbound_method, [uchar])
    if node.type.is_pyobject:
        # the coercion needs the environment object itself, so current_env
        # must be called rather than passed as a bound method
        func_call = func_call.coerce_to_pyobject(self.current_env())
    return func_call
| 2676 |
# The single character case conversions share one handler, which derives
# the C function name from the name of the called method.
_handle_simple_method_unicode_lower = _inject_unicode_character_conversion
_handle_simple_method_unicode_upper = _inject_unicode_character_conversion
_handle_simple_method_unicode_title = _inject_unicode_character_conversion
| 2680 |
# C signature used for the PyUnicode_Splitlines() call: takes a unicode
# string and a C boolean 'keepends' flag, and returns a list.
PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
    Builtin.list_type, [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
        ])
| 2686 |
def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
    """Map unicode.splitlines([keepends]) directly to the C-API function
    PyUnicode_Splitlines().

    A missing 'keepends' argument defaults to False; a given one is
    coerced to a C boolean.
    """
    if not 1 <= len(args) <= 2:
        self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
        return node

    keepends_index = 1
    self._inject_bint_default_argument(node, args, keepends_index, False)

    return self._substitute_method_call(
        node, function,
        "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
        'splitlines', is_unbound_method, args)
| 2700 |
# C signature used for the PyUnicode_Split() call: takes a unicode string,
# a separator object and a Py_ssize_t 'maxsplit' count, and returns a list.
PyUnicode_Split_func_type = PyrexTypes.CFuncType(
    Builtin.list_type, [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
        ]
    )
| 2708 |
def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
    """Map unicode.split([sep [, maxsplit]]) directly to the C-API
    function PyUnicode_Split().

    A missing separator is passed as NULL, a missing 'maxsplit' as -1.
    """
    arg_count = len(args)
    if not 1 <= arg_count <= 3:
        self._error_wrong_arg_count('unicode.split', node, args, "1-3")
        return node

    if arg_count == 1:
        # no separator given
        args.append(ExprNodes.NullNode(node.pos))
    self._inject_int_default_argument(
        node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")

    return self._substitute_method_call(
        node, function,
        "PyUnicode_Split", self.PyUnicode_Split_func_type,
        'split', is_unbound_method, args)
| 2725 |
# C signature used for the "__Pyx_Py<Type>_Tailmatch" helper calls built by
# _inject_tailmatch(): string, substring, start/end offsets and a direction
# flag; returns a C boolean, with -1 signalling an error.
PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type, [
        PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
    exception_value = '-1')
| 2735 |
def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
    """Optimise unicode.endswith() through the shared tail-match helper
    (direction +1).
    """
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        type_name='unicode', method_name='endswith',
        utility_code=unicode_tailmatch_utility_code, direction=+1)
| 2740 |
def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
    """Optimise unicode.startswith() through the shared tail-match helper
    (direction -1).
    """
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        type_name='unicode', method_name='startswith',
        utility_code=unicode_tailmatch_utility_code, direction=-1)
| 2745 |
def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
                      method_name, utility_code, direction):
    """Replace <type>.startswith(...) and <type>.endswith(...) by a
    direct call to the "__Pyx_Py<Type>_Tailmatch" helper function.

    Missing 'start' / 'end' arguments default to 0 / PY_SSIZE_T_MAX,
    and 'direction' is appended as the last C argument.  The result is
    coerced to a Python bool.
    """
    if not 2 <= len(args) <= 4:
        self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    for position, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
        self._inject_int_default_argument(node, args, position, ssize_t, default)
    direction_node = ExprNodes.IntNode(
        node.pos, value=str(direction), type=PyrexTypes.c_int_type)
    args.append(direction_node)

    tailmatch_call = self._substitute_method_call(
        node, function,
        "__Pyx_Py%s_Tailmatch" % type_name.capitalize(),
        self.PyString_Tailmatch_func_type,
        method_name, is_unbound_method, args,
        utility_code=utility_code)
    return tailmatch_call.coerce_to(Builtin.bool_type, self.current_env())
| 2768 |
# C signature used for the PyUnicode_Find() call: string, substring,
# start/end offsets and a direction flag; returns a Py_ssize_t, with -2
# signalling an error.
PyUnicode_Find_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type, [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
    exception_value = '-2')
| 2778 |
def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
    """Optimise unicode.find() through the shared find helper
    (direction +1).
    """
    return self._inject_unicode_find(
        node, function, args, is_unbound_method,
        method_name='find', direction=+1)
| 2782 |
def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
    """Optimise unicode.rfind() through the shared find helper
    (direction -1).
    """
    return self._inject_unicode_find(
        node, function, args, is_unbound_method,
        method_name='rfind', direction=-1)
| 2786 |
def _inject_unicode_find(self, node, function, args, is_unbound_method,
                         method_name, direction):
    """Map unicode.find(...) and unicode.rfind(...) directly to the
    C-API function PyUnicode_Find().

    Missing 'start' / 'end' arguments default to 0 / PY_SSIZE_T_MAX,
    and 'direction' is appended as the last C argument.  The C result
    is coerced to a Python object.
    """
    if not 2 <= len(args) <= 4:
        self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    for position, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
        self._inject_int_default_argument(node, args, position, ssize_t, default)
    direction_node = ExprNodes.IntNode(
        node.pos, value=str(direction), type=PyrexTypes.c_int_type)
    args.append(direction_node)

    find_call = self._substitute_method_call(
        node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
        method_name, is_unbound_method, args)
    return find_call.coerce_to_pyobject(self.current_env())
| 2806 |
# C signature used for the PyUnicode_Count() call: string, substring and
# start/end offsets; returns a Py_ssize_t, with -1 signalling an error.
PyUnicode_Count_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type, [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        ],
    exception_value = '-1')
| 2815 |
def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
    """Map unicode.count(...) directly to the C-API function
    PyUnicode_Count().

    Missing 'start' / 'end' arguments default to 0 / PY_SSIZE_T_MAX.
    The C result is coerced to a Python object.
    """
    if not 2 <= len(args) <= 4:
        self._error_wrong_arg_count('unicode.count', node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    for position, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
        self._inject_int_default_argument(node, args, position, ssize_t, default)

    count_call = self._substitute_method_call(
        node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
        'count', is_unbound_method, args)
    return count_call.coerce_to_pyobject(self.current_env())
| 2832 |
# C signature used for the PyUnicode_Replace() call: string, substring,
# replacement and a Py_ssize_t 'maxcount'; returns a unicode string.
PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
        ])
| 2840 |
def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
    """Map unicode.replace(...) directly to the C-API function
    PyUnicode_Replace().

    A missing 'maxcount' argument defaults to -1.
    """
    if not 3 <= len(args) <= 4:
        self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
        return node

    maxcount_index = 3
    self._inject_int_default_argument(
        node, args, maxcount_index, PyrexTypes.c_py_ssize_t_type, "-1")

    return self._substitute_method_call(
        node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
        'replace', is_unbound_method, args)
| 2854 |
# C signature used for the generic PyUnicode_AsEncodedString() call:
# unicode object plus encoding and error handling names as C strings;
# returns a bytes object.
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ])

# C signature used for the codec specific "PyUnicode_As<Codec>String" calls,
# which take only the unicode object.
PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        ])

# Encodings for which codec specific C function names are generated.
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                      'unicode_escape', 'raw_unicode_escape']

# (name, encoder function) pairs; _find_special_codec_name() compares the
# encoder functions to recognise a requested encoding.
_special_codecs = [ (name, codecs.getencoder(name))
                    for name in _special_encodings ]
| 2872 |
def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
    """Replace unicode.encode(...) by a direct C-API call to the
    corresponding codec.

    Strategies, in order:

    1. No encoding given: call PyUnicode_AsEncodedString() with NULL
       for both the encoding and the error handling.
    2. Unicode literal with a known encoding: encode at compile time
       and return a bytes literal node.
    3. Known special codec with 'strict' error handling: call the
       codec specific "PyUnicode_As<Codec>String" function.
    4. Otherwise: call the generic PyUnicode_AsEncodedString().

    The call is returned unchanged if the argument count is wrong or
    the encoding / error arguments cannot be unpacked.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
        return node

    string_node = args[0]

    if len(args) == 1:
        null_node = ExprNodes.NullNode(node.pos)
        return self._substitute_method_call(
            node, function, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method, [string_node, null_node, null_node])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
        # constant, so try to do the encoding at compile time
        try:
            value = string_node.value.encode(encoding, error_handling)
        except Exception:
            # well, looks like we can't - fall through to a runtime call;
            # KeyboardInterrupt and SystemExit are deliberately not caught
            pass
        else:
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)

    if encoding and error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, function, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])

    return self._substitute_method_call(
        node, function, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
| 2923 |
# Pointer type for the codec specific "PyUnicode_Decode<Codec>" functions:
# C string, its size and the error handling name; returns a unicode string.
PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ]))

# C signature used for the "__Pyx_decode_c_string" helper call (C char*
# input with start/stop offsets and an optional codec specific decoder).
_decode_c_string_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ])

# C signature used for the "__Pyx_decode_bytes" / "__Pyx_decode_bytearray"
# helper calls (same as above, but with a Python object as input).
_decode_bytes_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ])

# Signature for the C++ string variant; created on first use inside
# _handle_simple_method_bytes_decode() because it needs the C++ string type.
_decode_cpp_string_func_type = None # lazy init
| 2952 |
def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
    """Replace char*.decode() by a direct C-API call to the
    corresponding codec, possibly resolving a slice on the char*.

    Handles C strings, C++ strings and Python bytes/bytearray objects
    as receivers; for any other receiver type the call is returned
    unchanged.  The call is also returned unchanged when the argument
    count is wrong (after calling _error_wrong_arg_count()) or when
    the encoding / error handling arguments cannot be unpacked.

    The generated call has the form
    __Pyx_<helper>(string, start, stop, encoding, errors, decode_func).
    """
    if not (1 <= len(args) <= 3):
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
        return node

    # normalise input nodes
    string_node = args[0]
    start = stop = None
    if isinstance(string_node, ExprNodes.SliceIndexNode):
        # decode the sliced base directly and pass the slice bounds along
        index_node = string_node
        string_node = index_node.base
        start, stop = index_node.start, index_node.stop
        if not start or start.constant_result == 0:
            start = None
    if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
        # look at the value underneath the coercion to a Python object
        string_node = string_node.arg

    string_type = string_node.type
    if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
        # Python object receiver: wrap it via as_none_safe_node(), using the
        # message that matches the kind of call
        if is_unbound_method:
            string_node = string_node.as_none_safe_node(
                "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                format_args=['decode', string_type.name])
        else:
            string_node = string_node.as_none_safe_node(
                "'NoneType' object has no attribute '%s'",
                error="PyExc_AttributeError",
                format_args=['decode'])
    elif not string_type.is_string and not string_type.is_cpp_string:
        # nothing to optimise here
        return node

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # make sure 'start' (and 'stop', if given) are C integers
    if not start:
        start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
    elif not start.type.is_int:
        start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
    if stop and not stop.type.is_int:
        stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

    # try to find a specific encoder function
    codec_name = None
    if encoding is not None:
        codec_name = self._find_special_codec_name(encoding)
    if codec_name is not None:
        # pass the codec specific decoder and NULL as the encoding name
        decode_function = ExprNodes.RawCNameExprNode(
            node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type,
            cname="PyUnicode_Decode%s" % codec_name)
        encoding_node = ExprNodes.NullNode(node.pos)
    else:
        decode_function = ExprNodes.NullNode(node.pos)

    # build the helper function call
    temps = []
    if string_type.is_string:
        # C string
        if not stop:
            # use strlen() to find the string length, just as CPython would
            if not string_node.is_name:
                string_node = UtilNodes.LetRefNode(string_node) # used twice
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                args=[string_node],
                is_temp=False,
                utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
            ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        helper_func_type = self._decode_c_string_func_type
        utility_code_name = 'decode_c_string'
    elif string_type.is_cpp_string:
        # C++ std::string
        if not stop:
            stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                     constant_result=ExprNodes.not_a_constant)
        if self._decode_cpp_string_func_type is None:
            # lazy init to reuse the C++ string type
            self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                Builtin.unicode_type, [
                    PyrexTypes.CFuncTypeArg("string", string_type, None),
                    PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                    PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                    PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
                    PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
                    PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                    ])
        helper_func_type = self._decode_cpp_string_func_type
        utility_code_name = 'decode_cpp_string'
    else:
        # Python bytes/bytearray object
        if not stop:
            stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                     constant_result=ExprNodes.not_a_constant)
        helper_func_type = self._decode_bytes_func_type
        if string_type is Builtin.bytes_type:
            utility_code_name = 'decode_bytes'
        else:
            utility_code_name = 'decode_bytearray'

    # the C helper function is named after its utility code section
    node = ExprNodes.PythonCapiCallNode(
        node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
        args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
        is_temp=node.is_temp,
        utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
    )

    # wrap the call in the temps collected above, last one innermost
    for temp in temps[::-1]:
        node = UtilNodes.EvalWithTempExprNode(temp, node)
    return node
| 3068 |
# bytearray.decode() shares the bytes handler, which distinguishes the two
# receiver types when selecting the C helper.
_handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
| 3070 |
| 3071 def _find_special_codec_name(self, encoding): |
| 3072 try: |
| 3073 requested_codec = codecs.getencoder(encoding) |
| 3074 except LookupError: |
| 3075 return None |
| 3076 for name, codec in self._special_codecs: |
| 3077 if codec == requested_codec: |
| 3078 if '_' in name: |
| 3079 name = ''.join([s.capitalize() |
| 3080 for s in name.split('_')]) |
| 3081 return name |
| 3082 return None |
| 3083 |
def _unpack_encoding_and_error_mode(self, pos, args):
    """Extract the encoding and error handling arguments of an
    encode()/decode() call.

    Returns the tuple (encoding, encoding_node, error_handling,
    error_handling_node), or None if one of the arguments cannot be
    unpacked.  A missing encoding yields (None, NULL node); missing or
    'strict' error handling yields ('strict', NULL node).
    """
    null_node = ExprNodes.NullNode(pos)
    arg_count = len(args)

    encoding, encoding_node = None, null_node
    if arg_count >= 2:
        encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
        if encoding_node is None:
            return None

    error_handling, error_handling_node = 'strict', null_node
    if arg_count == 3:
        error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
        if error_handling_node is None:
            return None
        if error_handling == 'strict':
            # 'strict' is passed to the C function as NULL
            error_handling_node = null_node

    return (encoding, encoding_node, error_handling, error_handling_node)
| 3106 |
def _unpack_string_and_cstring_node(self, node):
    """Return (text, c_string_node) for a string argument node.

    'text' is the literal's value if it is known at compile time,
    otherwise None.  'c_string_node' is a node of C char* type, or
    None if the argument is not a supported kind of string.
    """
    if isinstance(node, ExprNodes.CoerceToPyTypeNode):
        node = node.arg

    if isinstance(node, ExprNodes.UnicodeNode):
        text = node.value
        c_node = ExprNodes.BytesNode(
            node.pos, value=BytesLiteral(text.utf8encode()),
            type=PyrexTypes.c_char_ptr_type)
        return text, c_node

    if isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
        text = node.value.decode('ISO-8859-1')
        c_node = ExprNodes.BytesNode(
            node.pos, value=node.value, type=PyrexTypes.c_char_ptr_type)
        return text, c_node

    if node.type is Builtin.bytes_type:
        # value only known at runtime, but convertible to a C string
        return None, node.coerce_to(PyrexTypes.c_char_ptr_type, self.current_env())

    if node.type.is_string:
        # already a C string
        return None, node

    return None, None
| 3127 |
def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
    """Optimise str.endswith() through the shared tail-match helper
    (direction +1).
    """
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        type_name='str', method_name='endswith',
        utility_code=str_tailmatch_utility_code, direction=+1)
| 3132 |
def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
    """Optimise str.startswith() through the shared tail-match helper
    (direction -1).
    """
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        type_name='str', method_name='startswith',
        utility_code=str_tailmatch_utility_code, direction=-1)
| 3137 |
def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
    """Optimise bytes.endswith() through the shared tail-match helper
    (direction +1).
    """
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        type_name='bytes', method_name='endswith',
        utility_code=bytes_tailmatch_utility_code, direction=+1)
| 3142 |
def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
    """Optimise bytes.startswith() through the shared tail-match helper
    (direction -1).
    """
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        type_name='bytes', method_name='startswith',
        utility_code=bytes_tailmatch_utility_code, direction=-1)
| 3147 |
| 3148 ''' # disabled for now, enable when we consider it worth it (see StringToo
ls.c) |
| 3149 def _handle_simple_method_bytearray_endswith(self, node, function, args, is_
unbound_method): |
| 3150 return self._inject_tailmatch( |
| 3151 node, function, args, is_unbound_method, 'bytearray', 'endswith', |
| 3152 bytes_tailmatch_utility_code, +1) |
| 3153 |
| 3154 def _handle_simple_method_bytearray_startswith(self, node, function, args, i
s_unbound_method): |
| 3155 return self._inject_tailmatch( |
| 3156 node, function, args, is_unbound_method, 'bytearray', 'startswith', |
| 3157 bytes_tailmatch_utility_code, -1) |
| 3158 ''' |
| 3159 |
| 3160 ### helpers |
| 3161 |
def _substitute_method_call(self, node, function, name, func_type,
                            attr_name, is_unbound_method, args=(),
                            utility_code=None, is_temp=None,
                            may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
    """Build the C-API call node that replaces a Python method call.

    'name' and 'func_type' describe the C function, 'attr_name' is the
    Python method name used in error messages.  A non-literal first
    argument (the receiver) is wrapped by as_none_safe_node(), with a
    message that depends on whether the method was called unbound.
    If 'is_temp' is None, it is taken from the original node.
    """
    call_args = list(args)
    if call_args and not call_args[0].is_literal:
        receiver = call_args[0]
        if is_unbound_method:
            call_args[0] = receiver.as_none_safe_node(
                "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                format_args=[attr_name, function.obj.name])
        else:
            call_args[0] = receiver.as_none_safe_node(
                "'NoneType' object has no attribute '%s'",
                error="PyExc_AttributeError",
                format_args=[attr_name])

    return ExprNodes.PythonCapiCallNode(
        node.pos, name, func_type,
        args=call_args,
        is_temp=node.is_temp if is_temp is None else is_temp,
        utility_code=utility_code,
        may_return_none=may_return_none,
        result_is_used=node.result_is_used,
        )
| 3189 |
def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
    """Make sure ``args[arg_index]`` is present and has the integer 'type'.

    Modifies 'args' in place.  An argument that already exists at
    'arg_index' is coerced to 'type' in the current environment.  If the
    list ends right before 'arg_index', an ``IntNode`` holding
    'default_value' is appended instead.
    """
    arg_count = len(args)
    assert arg_count >= arg_index
    if arg_count != arg_index:
        # argument was passed explicitly => only adapt its type
        args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
        return
    default_node = ExprNodes.IntNode(
        node.pos, value=str(default_value),
        type=type, constant_result=default_value)
    args.append(default_node)
| 3197 |
def _inject_bint_default_argument(self, node, args, arg_index, default_value):
    """Make sure ``args[arg_index]`` is present and is a boolean value.

    Modifies 'args' in place.  An argument that already exists at
    'arg_index' is coerced to a boolean in the current environment.  If
    the list ends right before 'arg_index', a ``BoolNode`` holding
    ``bool(default_value)`` is appended instead.
    """
    arg_count = len(args)
    assert arg_count >= arg_index
    if arg_count != arg_index:
        # argument was passed explicitly => only adapt its type
        args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
        return
    flag = bool(default_value)
    args.append(ExprNodes.BoolNode(node.pos, value=flag, constant_result=flag))
| 3206 |
| 3207 |
# Cached utility code sections from 'StringTools.c' that are passed to
# _inject_tailmatch() by the startswith()/endswith() handlers above.
(unicode_tailmatch_utility_code,
 bytes_tailmatch_utility_code,
 str_tailmatch_utility_code) = (
    UtilityCode.load_cached(_section, 'StringTools.c')
    for _section in ('unicode_tailmatch', 'bytes_tailmatch', 'str_tailmatch'))
| 3211 |
| 3212 |
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """

    # Python integer types that are accepted as constant multiplication
    # factors of sequences.  'long' only exists in Python 2; looking it
    # up unconditionally would raise a NameError in Python 3.
    try:
        _py_int_types = (int, long)
    except NameError:
        _py_int_types = (int,)

    def __init__(self, reevaluate=False):
        """
        The reevaluate argument specifies whether constant values that were
        previously computed should be recomputed.
        """
        super(ConstantFolding, self).__init__()
        self.reevaluate = reevaluate

    def _calculate_const(self, node):
        """Visit the children of 'node' and try to set its constant result.

        Does nothing if the node already carries a result and
        re-evaluation was not requested.  Otherwise, the result is first
        reset to ``not_a_constant`` and only recalculated if all visited
        children have a constant result.
        """
        if (not self.reevaluate and
                node.constant_result is not ExprNodes.constant_value_not_set):
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                candidates = child_result
            else:
                candidates = [child_result]
            for child in candidates:
                if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                    return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, ordered from the narrowest to the widest
    NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
                       ExprNodes.IntNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the class from NODE_TYPE_ORDER that ranks highest among
        the exact types of 'nodes', or None if any node has a type that
        is not listed there (or if no nodes were passed).
        """
        order = self.NODE_TYPE_ORDER
        try:
            return order[max([order.index(type(n)) for n in nodes])]
        except ValueError:
            return None

    def _bool_node(self, node, value):
        """Create a BoolNode for ``bool(value)`` at the position of 'node'."""
        value = bool(value)
        return ExprNodes.BoolNode(node.pos, value=value, constant_result=value)

    def visit_ExprNode(self, node):
        """Generic expressions: calculate the constant, keep the node."""
        self._calculate_const(node)
        return node

    def visit_UnopNode(self, node):
        """Fold unary operators that are applied to literal operands."""
        self._calculate_const(node)
        if not node.has_constant_result():
            if node.operator == '!':
                return self._handle_NotNode(node)
            return node
        if not node.operand.is_literal:
            return node
        if node.operator == '!':
            return self._bool_node(node, node.constant_result)
        elif isinstance(node.operand, ExprNodes.BoolNode):
            int_value = int(node.constant_result)
            return ExprNodes.IntNode(node.pos, value=str(int_value),
                                     type=PyrexTypes.c_int_type,
                                     constant_result=int(node.constant_result))
        elif node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        elif node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

    # maps a comparison operator to its negation, None if not negatable here
    _negate_operator = {
        'in': 'not_in',
        'not_in': 'in',
        'is': 'is_not',
        'is_not': 'is'
    }.get

    def _handle_NotNode(self, node):
        """Replace 'not (a in b)' and 'not (a is b)' style expressions by a
        copy of the comparison that uses the negated operator.
        """
        operand = node.operand
        if isinstance(operand, ExprNodes.PrimaryCmpNode):
            operator = self._negate_operator(operand.operator)
            if operator:
                node = copy.copy(operand)
                node.operator = operator
                node = self.visit_PrimaryCmpNode(node)
        return node

    def _handle_UnaryMinusNode(self, node):
        """Fold a unary minus into a float literal, a signed C integer
        operand or a Python integer literal by flipping the sign of the
        literal's string value.
        """
        def _negate(value):
            if value.startswith('-'):
                return value[1:]
            return '-' + value

        node_type = node.operand.type
        if isinstance(node.operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value),
                                       type=node_type,
                                       constant_result=node.constant_result)
        if ((node_type.is_int and node_type.signed) or
                (isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject)):
            return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value),
                                     type=node_type,
                                     longness=node.operand.longness,
                                     constant_result=node.constant_result)
        return node

    def _handle_UnaryPlusNode(self, node):
        """Drop a unary plus that does not change the operand's constant."""
        if (node.operand.has_constant_result() and
                node.constant_result == node.operand.constant_result):
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Short-circuit 'and'/'or' when the first operand is constant."""
        self._calculate_const(node)
        if not node.operand1.has_constant_result():
            return node
        if node.operand1.constant_result:
            # true value: 'and' continues with operand2, 'or' stops here
            if node.operator == 'and':
                return node.operand2
            return node.operand1
        # false value: 'and' stops here, 'or' continues with operand2
        if node.operator == 'and':
            return node.operand1
        return node.operand2

    def visit_BinopNode(self, node):
        """Replace a binary operation on two literals by a single literal
        node, except for float results, which are never aggregated.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
        except AttributeError:
            return node
        if type1 is None or type2 is None:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type

        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>':
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode
        elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^':
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode

        if target_class is ExprNodes.IntNode:
            # only unsigned if both operands are, as long as the longer one
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned=unsigned, longness=longness,
                                         value=str(int(node.constant_result)),
                                         constant_result=int(node.constant_result))
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type=widest_type,
                                    value=node_value,
                                    constant_result=node.constant_result)
        return new_node

    def visit_MulNode(self, node):
        """Fold 'sequence * factor' and 'int literal * sequence' into the
        sequence node, handle everything else like a normal binary operator.
        """
        self._calculate_const(node)
        if node.operand1.is_sequence_constructor:
            return self._calculate_constant_seq(node, node.operand1, node.operand2)
        if isinstance(node.operand1, ExprNodes.IntNode) and \
                node.operand2.is_sequence_constructor:
            return self._calculate_constant_seq(node, node.operand2, node.operand1)
        return self.visit_BinopNode(node)

    def _calculate_constant_seq(self, node, sequence_node, factor):
        """Merge the multiplication 'factor' into 'sequence_node'.

        Returns the (modified) sequence node, or the result of
        visit_BinopNode() for 'node' if an existing factor of the
        sequence cannot be combined with the new one.
        """
        int_types = self._py_int_types
        if factor.constant_result != 1 and sequence_node.args:
            if isinstance(factor.constant_result, int_types) and factor.constant_result <= 0:
                # multiplying by an integer <= 0 => empty sequence
                del sequence_node.args[:]
                sequence_node.mult_factor = None
            elif sequence_node.mult_factor is not None:
                if (isinstance(factor.constant_result, int_types) and
                        isinstance(sequence_node.mult_factor.constant_result, int_types)):
                    value = sequence_node.mult_factor.constant_result * factor.constant_result
                    sequence_node.mult_factor = ExprNodes.IntNode(
                        sequence_node.mult_factor.pos,
                        value=str(value), constant_result=value)
                else:
                    # don't know if we can combine the factors, so don't
                    return self.visit_BinopNode(node)
            else:
                sequence_node.mult_factor = factor
        return sequence_node

    def visit_PrimaryCmpNode(self, node):
        """Fold constant comparisons, including the constant parts of a
        comparison cascade ('a < b < c'), which is split up into the
        remaining non-constant comparisons joined by 'and'.
        """
        # calculate constant partial results in the comparison cascade
        self.visitchildren(node, ['operand1'])
        left_node = node.operand1
        cmp_node = node
        while cmp_node is not None:
            self.visitchildren(cmp_node, ['operand2'])
            right_node = cmp_node.operand2
            cmp_node.constant_result = not_a_constant
            if left_node.has_constant_result() and right_node.has_constant_result():
                try:
                    cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
                except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                    pass  # ignore all 'normal' errors here => no constant result
            left_node = right_node
            cmp_node = cmp_node.cascade

        if not node.cascade:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
            return node

        # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
        cascades = [[node.operand1]]
        final_false_result = None

        cmp_node = node
        while True:
            if cmp_node.has_constant_result():
                if not cmp_node.constant_result:
                    # False => short-circuit
                    final_false_result = self._bool_node(cmp_node, False)
                    break
                # True => discard and start new cascade
                cascades.append([cmp_node.operand2])
            else:
                # not constant => append to current cascade
                cascades[-1].append(cmp_node)
            if not cmp_node.cascade:
                break
            cmp_node = cmp_node.cascade

        # rebuild one comparison (chain) per non-trivial partial cascade
        cmp_nodes = []
        for cascade in cascades:
            if len(cascade) < 2:
                continue
            cmp_node = cascade[1]
            pcmp_node = ExprNodes.PrimaryCmpNode(
                cmp_node.pos,
                operand1=cascade[0],
                operator=cmp_node.operator,
                operand2=cmp_node.operand2,
                constant_result=not_a_constant)
            cmp_nodes.append(pcmp_node)

            last_cmp_node = pcmp_node
            for cmp_node in cascade[2:]:
                last_cmp_node.cascade = cmp_node
                last_cmp_node = cmp_node
            last_cmp_node.cascade = None

        if final_false_result is not None:
            # last cascade was constant False
            cmp_nodes.append(final_false_result)
        elif not cmp_nodes:
            # only constants, but no False result
            return self._bool_node(node, True)

        node = cmp_nodes[0]
        if len(cmp_nodes) == 1:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
        else:
            for cmp_node in cmp_nodes[1:]:
                node = ExprNodes.BoolBinopNode(
                    node.pos,
                    operand1=node,
                    operator='and',
                    operand2=cmp_node,
                    constant_result=not_a_constant)
        return node

    def visit_CondExprNode(self, node):
        """Select the branch of 'a if test else b' for a constant test."""
        self._calculate_const(node)
        if not node.test.has_constant_result():
            return node
        if node.test.constant_result:
            return node.true_val
        else:
            return node.false_val

    def visit_IfStatNode(self, node):
        """Drop if-clauses with a constant false condition and everything
        that follows a clause with a constant true condition.
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition = if_clause.condition
            if not condition.has_constant_result():
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition.constant_result:
                # always true => subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            # else: false => drop clause
        if if_clauses:
            node.if_clauses = if_clauses
            return node
        elif node.else_clause:
            return node.else_clause
        else:
            return Nodes.StatListNode(node.pos, stats=[])

    def visit_SliceIndexNode(self, node):
        """Normalise None slice bounds and slice constant sequence or
        string literals at compile time.
        """
        self._calculate_const(node)
        # normalise start/stop values
        if node.start is None or node.start.constant_result is None:
            start = node.start = None
        else:
            start = node.start.constant_result
        if node.stop is None or node.stop.constant_result is None:
            stop = node.stop = None
        else:
            stop = node.stop.constant_result
        # cut down sliced constant sequences
        if node.constant_result is not not_a_constant:
            base = node.base
            if base.is_sequence_constructor and base.mult_factor is None:
                base.args = base.args[start:stop]
                return base
            elif base.is_string_literal:
                base = base.as_sliced_node(start, stop)
                if base is not None:
                    return base
        return node

    def visit_ComprehensionNode(self, node):
        """Turn a comprehension whose loop was reduced to an empty
        statement list into an empty list/set/dict literal.
        """
        self.visitchildren(node)
        if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
            # loop was pruned already => transform into literal
            if node.type is Builtin.list_type:
                return ExprNodes.ListNode(
                    node.pos, args=[], constant_result=[])
            elif node.type is Builtin.set_type:
                return ExprNodes.SetNode(
                    node.pos, args=[], constant_result=set())
            elif node.type is Builtin.dict_type:
                return ExprNodes.DictNode(
                    node.pos, key_value_pairs=[], constant_result={})
        return node

    def visit_ForInStatNode(self, node):
        """Remove loops over empty sequence literals and iterate over a
        tuple instead of a list literal.
        """
        self.visitchildren(node)
        sequence = node.iterator.sequence
        if isinstance(sequence, ExprNodes.SequenceNode):
            if not sequence.args:
                if node.else_clause:
                    return node.else_clause
                else:
                    # don't break list comprehensions
                    return Nodes.StatListNode(node.pos, stats=[])
            # iterating over a list literal? => tuples are more efficient
            if isinstance(sequence, ExprNodes.ListNode):
                node.iterator.sequence = sequence.as_tuple()
        return node

    def visit_WhileStatNode(self, node):
        """Simplify while loops that have a constant condition.

        For a constant true condition, both the condition and the else
        clause are removed from the loop node.  For a constant false
        condition, the loop is replaced by its else clause (which may be
        None).
        """
        self.visitchildren(node)
        if node.condition and node.condition.has_constant_result():
            if node.condition.constant_result:
                node.condition = None
                node.else_clause = None
            else:
                return node.else_clause
        return node

    def visit_ExprStatNode(self, node):
        """Drop expression statements that only evaluate a constant."""
        self.visitchildren(node)
        if not isinstance(node.expr, ExprNodes.ExprNode):
            # ParallelRangeTransform does this ...
            return node
        # drop unused constant expressions
        if node.expr.has_constant_result():
            return None
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
| 3646 |
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    Collection of independent late optimisations that runs right before
    the C code gets generated.

    Currently implemented:
        - mark the target of a first assignment, so that redundant None
          initialisation and refcounting can be avoided for it
        - isinstance(x, type) -> C-level type check
        - remove None checks and None alternatives of type tests that
          became redundant after earlier tree changes
    """
    def visit_SingleAssignmentNode(self, node):
        """Flag the left-hand side of a first assignment, so that the
        variable does not need a separate initialisation before it.
        """
        self.visitchildren(node)
        if node.first:
            node.lhs.lhs_of_first_assignment = True
        return node

    def visit_SimpleCallNode(self, node):
        """Redirect calls of the form isinstance(x, type) to the
        'PyObject_TypeCheck' entry of the Cython scope.
        """
        self.visitchildren(node)
        func = node.function
        if not (func.type.is_cfunction and isinstance(func, ExprNodes.NameNode)):
            return node
        if not (func.name == 'isinstance' and len(node.args) == 2):
            return node
        second_arg_type = node.args[1].type
        if not (second_arg_type.is_builtin_type and second_arg_type.name == 'type'):
            return node

        scope = self.context.cython_scope
        func.entry = scope.lookup('PyObject_TypeCheck')
        func.type = func.entry.type
        # the second argument is passed as a pointer to 'PyTypeObject'
        type_object_ptr = PyrexTypes.CPtrType(scope.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(node.args[1], type_object_ptr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Mark the type test as 'notnone' if its argument is known to
        never be None, so that None is no longer treated as an allowed
        alternative.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node

    def visit_NoneCheckNode(self, node):
        """Replace the None check by its argument if that argument is
        known to never be None.
        """
        self.visitchildren(node)
        return node if node.arg.may_be_none() else node.arg
| 3701 return node |
| 3702 |
class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    Lets all nodes of a nested arithmetic expression share a single
    overflow check.  For example, in the expression a*b + c, where a, b,
    and c are all possibly overflowing ints, the whole expression is
    evaluated first and the overflow bit is checked only once at the end.
    """
    # outermost arithmetic node of the expression that is currently being
    # traversed, None while outside of any such expression
    overflow_bit_node = None

    def visit_Node(self, node):
        """Visit the children of a non-arithmetic node with overflow
        sharing switched off, then restore the previous state.
        """
        enclosing = self.overflow_bit_node
        if enclosing is None:
            self.visitchildren(node)
            return node
        self.overflow_bit_node = None
        self.visitchildren(node)
        self.overflow_bit_node = enclosing
        return node

    def visit_NumBinopNode(self, node):
        """Attach nested overflow-checked operations to the outermost one.

        The outermost foldable operation is kept as the shared overflow
        bit node while its children are visited.  Foldable operations
        below it get a reference to that node and have their own
        overflow check disabled.
        """
        if not (node.overflow_check and node.overflow_fold):
            self.visitchildren(node)
            return node

        is_outermost = self.overflow_bit_node is None
        if is_outermost:
            self.overflow_bit_node = node
        else:
            node.overflow_bit_node = self.overflow_bit_node
            node.overflow_check = False
        self.visitchildren(node)
        if is_outermost:
            self.overflow_bit_node = None
        return node
OLD | NEW |