OLD | NEW |
(Empty) | |
| 1 import copy |
| 2 |
| 3 from Cython.Compiler import (ExprNodes, PyrexTypes, MemoryView, |
| 4 ParseTreeTransforms, StringEncoding, |
| 5 Errors) |
| 6 from Cython.Compiler.ExprNodes import CloneNode, ProxyNode, TupleNode |
| 7 from Cython.Compiler.Nodes import (FuncDefNode, CFuncDefNode, StatListNode, |
| 8 DefNode) |
| 9 |
| 10 class FusedCFuncDefNode(StatListNode): |
| 11 """ |
| 12 This node replaces a function with fused arguments. It deep-copies the |
| 13 function for every permutation of fused types, and allocates a new local |
| 14 scope for it. It keeps track of the original function in self.node, and |
| 15 the entry of the original function in the symbol table is given the |
| 16 'fused_cfunction' attribute which points back to us. |
| 17 Then when a function lookup occurs (to e.g. call it), the call can be |
| 18 dispatched to the right function. |
| 19 |
| 20 node FuncDefNode the original function |
| 21 nodes [FuncDefNode] list of copies of node with different specific types |
| 22 py_func DefNode the fused python function subscriptable from |
| 23 Python space |
| 24 __signatures__ A DictNode mapping signature specialization strings |
| 25 to PyCFunction nodes |
| 26 resulting_fused_function PyCFunction for the fused DefNode that delegates |
| 27 to specializations |
| 28 fused_func_assignment Assignment of the fused function to the function nam
e |
| 29 defaults_tuple TupleNode of defaults (letting PyCFunctionNode build |
| 30 defaults would result in many different tuples) |
| 31 specialized_pycfuncs List of synthesized pycfunction nodes for the |
| 32 specializations |
| 33 code_object CodeObjectNode shared by all specializations and the |
| 34 fused function |
| 35 |
| 36 fused_compound_types All fused (compound) types (e.g. floating[:]) |
| 37 """ |
| 38 |
| 39 __signatures__ = None |
| 40 resulting_fused_function = None |
| 41 fused_func_assignment = None |
| 42 defaults_tuple = None |
| 43 decorators = None |
| 44 |
| 45 child_attrs = StatListNode.child_attrs + [ |
| 46 '__signatures__', 'resulting_fused_function', 'fused_func_assignment'] |
| 47 |
| 48 def __init__(self, node, env): |
| 49 super(FusedCFuncDefNode, self).__init__(node.pos) |
| 50 |
| 51 self.nodes = [] |
| 52 self.node = node |
| 53 |
| 54 is_def = isinstance(self.node, DefNode) |
| 55 if is_def: |
| 56 # self.node.decorators = [] |
| 57 self.copy_def(env) |
| 58 else: |
| 59 self.copy_cdef(env) |
| 60 |
| 61 # Perform some sanity checks. If anything fails, it's a bug |
| 62 for n in self.nodes: |
| 63 assert not n.entry.type.is_fused |
| 64 assert not n.local_scope.return_type.is_fused |
| 65 if node.return_type.is_fused: |
| 66 assert not n.return_type.is_fused |
| 67 |
| 68 if not is_def and n.cfunc_declarator.optional_arg_count: |
| 69 assert n.type.op_arg_struct |
| 70 |
| 71 node.entry.fused_cfunction = self |
| 72 # Copy the nodes as AnalyseDeclarationsTransform will prepend |
| 73 # self.py_func to self.stats, as we only want specialized |
| 74 # CFuncDefNodes in self.nodes |
| 75 self.stats = self.nodes[:] |
| 76 |
| 77 def copy_def(self, env): |
| 78 """ |
| 79 Create a copy of the original def or lambda function for specialized |
| 80 versions. |
| 81 """ |
| 82 fused_compound_types = PyrexTypes.unique( |
| 83 [arg.type for arg in self.node.args if arg.type.is_fused]) |
| 84 permutations = PyrexTypes.get_all_specialized_permutations(fused_compoun
d_types) |
| 85 |
| 86 self.fused_compound_types = fused_compound_types |
| 87 |
| 88 if self.node.entry in env.pyfunc_entries: |
| 89 env.pyfunc_entries.remove(self.node.entry) |
| 90 |
| 91 for cname, fused_to_specific in permutations: |
| 92 copied_node = copy.deepcopy(self.node) |
| 93 |
| 94 self._specialize_function_args(copied_node.args, fused_to_specific) |
| 95 copied_node.return_type = self.node.return_type.specialize( |
| 96 fused_to_specific) |
| 97 |
| 98 copied_node.analyse_declarations(env) |
| 99 # copied_node.is_staticmethod = self.node.is_staticmethod |
| 100 # copied_node.is_classmethod = self.node.is_classmethod |
| 101 self.create_new_local_scope(copied_node, env, fused_to_specific) |
| 102 self.specialize_copied_def(copied_node, cname, self.node.entry, |
| 103 fused_to_specific, fused_compound_types) |
| 104 |
| 105 PyrexTypes.specialize_entry(copied_node.entry, cname) |
| 106 copied_node.entry.used = True |
| 107 env.entries[copied_node.entry.name] = copied_node.entry |
| 108 |
| 109 if not self.replace_fused_typechecks(copied_node): |
| 110 break |
| 111 |
| 112 self.orig_py_func = self.node |
| 113 self.py_func = self.make_fused_cpdef(self.node, env, is_def=True) |
| 114 |
| 115 def copy_cdef(self, env): |
| 116 """ |
| 117 Create a copy of the original c(p)def function for all specialized |
| 118 versions. |
| 119 """ |
| 120 permutations = self.node.type.get_all_specialized_permutations() |
| 121 # print 'Node %s has %d specializations:' % (self.node.entry.name, |
| 122 # len(permutations)) |
| 123 # import pprint; pprint.pprint([d for cname, d in permutations]) |
| 124 |
| 125 if self.node.entry in env.cfunc_entries: |
| 126 env.cfunc_entries.remove(self.node.entry) |
| 127 |
| 128 # Prevent copying of the python function |
| 129 self.orig_py_func = orig_py_func = self.node.py_func |
| 130 self.node.py_func = None |
| 131 if orig_py_func: |
| 132 env.pyfunc_entries.remove(orig_py_func.entry) |
| 133 |
| 134 fused_types = self.node.type.get_fused_types() |
| 135 self.fused_compound_types = fused_types |
| 136 |
| 137 for cname, fused_to_specific in permutations: |
| 138 copied_node = copy.deepcopy(self.node) |
| 139 |
| 140 # Make the types in our CFuncType specific |
| 141 type = copied_node.type.specialize(fused_to_specific) |
| 142 entry = copied_node.entry |
| 143 |
| 144 copied_node.type = type |
| 145 entry.type, type.entry = type, entry |
| 146 |
| 147 entry.used = (entry.used or |
| 148 self.node.entry.defined_in_pxd or |
| 149 env.is_c_class_scope or |
| 150 entry.is_cmethod) |
| 151 |
| 152 if self.node.cfunc_declarator.optional_arg_count: |
| 153 self.node.cfunc_declarator.declare_optional_arg_struct( |
| 154 type, env, fused_cname=cname) |
| 155 |
| 156 copied_node.return_type = type.return_type |
| 157 self.create_new_local_scope(copied_node, env, fused_to_specific) |
| 158 |
| 159 # Make the argument types in the CFuncDeclarator specific |
| 160 self._specialize_function_args(copied_node.cfunc_declarator.args, |
| 161 fused_to_specific) |
| 162 |
| 163 type.specialize_entry(entry, cname) |
| 164 env.cfunc_entries.append(entry) |
| 165 |
| 166 # If a cpdef, declare all specialized cpdefs (this |
| 167 # also calls analyse_declarations) |
| 168 copied_node.declare_cpdef_wrapper(env) |
| 169 if copied_node.py_func: |
| 170 env.pyfunc_entries.remove(copied_node.py_func.entry) |
| 171 |
| 172 self.specialize_copied_def( |
| 173 copied_node.py_func, cname, self.node.entry.as_variable, |
| 174 fused_to_specific, fused_types) |
| 175 |
| 176 if not self.replace_fused_typechecks(copied_node): |
| 177 break |
| 178 |
| 179 if orig_py_func: |
| 180 self.py_func = self.make_fused_cpdef(orig_py_func, env, |
| 181 is_def=False) |
| 182 else: |
| 183 self.py_func = orig_py_func |
| 184 |
| 185 def _specialize_function_args(self, args, fused_to_specific): |
| 186 for arg in args: |
| 187 if arg.type.is_fused: |
| 188 arg.type = arg.type.specialize(fused_to_specific) |
| 189 if arg.type.is_memoryviewslice: |
| 190 MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype) |
| 191 |
| 192 def create_new_local_scope(self, node, env, f2s): |
| 193 """ |
| 194 Create a new local scope for the copied node and append it to |
| 195 self.nodes. A new local scope is needed because the arguments with the |
| 196 fused types are aready in the local scope, and we need the specialized |
| 197 entries created after analyse_declarations on each specialized version |
| 198 of the (CFunc)DefNode. |
| 199 f2s is a dict mapping each fused type to its specialized version |
| 200 """ |
| 201 node.create_local_scope(env) |
| 202 node.local_scope.fused_to_specific = f2s |
| 203 |
| 204 # This is copied from the original function, set it to false to |
| 205 # stop recursion |
| 206 node.has_fused_arguments = False |
| 207 self.nodes.append(node) |
| 208 |
| 209 def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types): |
| 210 """Specialize the copy of a DefNode given the copied node, |
| 211 the specialization cname and the original DefNode entry""" |
| 212 type_strings = [ |
| 213 PyrexTypes.specialization_signature_string(fused_type, f2s) |
| 214 for fused_type in fused_types |
| 215 ] |
| 216 |
| 217 node.specialized_signature_string = '|'.join(type_strings) |
| 218 |
| 219 node.entry.pymethdef_cname = PyrexTypes.get_fused_cname( |
| 220 cname, node.entry.pymethdef_cname) |
| 221 node.entry.doc = py_entry.doc |
| 222 node.entry.doc_cname = py_entry.doc_cname |
| 223 |
| 224 def replace_fused_typechecks(self, copied_node): |
| 225 """ |
| 226 Branch-prune fused type checks like |
| 227 |
| 228 if fused_t is int: |
| 229 ... |
| 230 |
| 231 Returns whether an error was issued and whether we should stop in |
| 232 in order to prevent a flood of errors. |
| 233 """ |
| 234 num_errors = Errors.num_errors |
| 235 transform = ParseTreeTransforms.ReplaceFusedTypeChecks( |
| 236 copied_node.local_scope) |
| 237 transform(copied_node) |
| 238 |
| 239 if Errors.num_errors > num_errors: |
| 240 return False |
| 241 |
| 242 return True |
| 243 |
| 244 def _fused_instance_checks(self, normal_types, pyx_code, env): |
| 245 """ |
| 246 Genereate Cython code for instance checks, matching an object to |
| 247 specialized types. |
| 248 """ |
| 249 if_ = 'if' |
| 250 for specialized_type in normal_types: |
| 251 # all_numeric = all_numeric and specialized_type.is_numeric |
| 252 py_type_name = specialized_type.py_type_name() |
| 253 specialized_type_name = specialized_type.specialization_string |
| 254 pyx_code.context.update(locals()) |
| 255 pyx_code.put_chunk( |
| 256 u""" |
| 257 {{if_}} isinstance(arg, {{py_type_name}}): |
| 258 dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}' |
| 259 """) |
| 260 if_ = 'elif' |
| 261 |
| 262 if not normal_types: |
| 263 # we need an 'if' to match the following 'else' |
| 264 pyx_code.putln("if 0: pass") |
| 265 |
| 266 def _dtype_name(self, dtype): |
| 267 if dtype.is_typedef: |
| 268 return '___pyx_%s' % dtype |
| 269 return str(dtype).replace(' ', '_') |
| 270 |
| 271 def _dtype_type(self, dtype): |
| 272 if dtype.is_typedef: |
| 273 return self._dtype_name(dtype) |
| 274 return str(dtype) |
| 275 |
| 276 def _sizeof_dtype(self, dtype): |
| 277 if dtype.is_pyobject: |
| 278 return 'sizeof(void *)' |
| 279 else: |
| 280 return "sizeof(%s)" % self._dtype_type(dtype) |
| 281 |
| 282 def _buffer_check_numpy_dtype_setup_cases(self, pyx_code): |
| 283 "Setup some common cases to match dtypes against specializations" |
| 284 if pyx_code.indenter("if dtype.kind in ('i', 'u'):"): |
| 285 pyx_code.putln("pass") |
| 286 pyx_code.named_insertion_point("dtype_int") |
| 287 pyx_code.dedent() |
| 288 |
| 289 if pyx_code.indenter("elif dtype.kind == 'f':"): |
| 290 pyx_code.putln("pass") |
| 291 pyx_code.named_insertion_point("dtype_float") |
| 292 pyx_code.dedent() |
| 293 |
| 294 if pyx_code.indenter("elif dtype.kind == 'c':"): |
| 295 pyx_code.putln("pass") |
| 296 pyx_code.named_insertion_point("dtype_complex") |
| 297 pyx_code.dedent() |
| 298 |
| 299 if pyx_code.indenter("elif dtype.kind == 'O':"): |
| 300 pyx_code.putln("pass") |
| 301 pyx_code.named_insertion_point("dtype_object") |
| 302 pyx_code.dedent() |
| 303 |
| 304 match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'" |
| 305 no_match = "dest_sig[{{dest_sig_idx}}] = None" |
| 306 def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types): |
| 307 """ |
| 308 Match a numpy dtype object to the individual specializations. |
| 309 """ |
| 310 self._buffer_check_numpy_dtype_setup_cases(pyx_code) |
| 311 |
| 312 for specialized_type in specialized_buffer_types: |
| 313 dtype = specialized_type.dtype |
| 314 pyx_code.context.update( |
| 315 itemsize_match=self._sizeof_dtype(dtype) + " == itemsize", |
| 316 signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_n
ame(dtype), |
| 317 dtype=dtype, |
| 318 specialized_type_name=specialized_type.specialization_string) |
| 319 |
| 320 dtypes = [ |
| 321 (dtype.is_int, pyx_code.dtype_int), |
| 322 (dtype.is_float, pyx_code.dtype_float), |
| 323 (dtype.is_complex, pyx_code.dtype_complex) |
| 324 ] |
| 325 |
| 326 for dtype_category, codewriter in dtypes: |
| 327 if dtype_category: |
| 328 cond = '{{itemsize_match}} and arg.ndim == %d' % ( |
| 329 specialized_type.ndim,) |
| 330 if dtype.is_int: |
| 331 cond += ' and {{signed_match}}' |
| 332 |
| 333 if codewriter.indenter("if %s:" % cond): |
| 334 # codewriter.putln("print 'buffer match found based on n
umpy dtype'") |
| 335 codewriter.putln(self.match) |
| 336 codewriter.putln("break") |
| 337 codewriter.dedent() |
| 338 |
| 339 def _buffer_parse_format_string_check(self, pyx_code, decl_code, |
| 340 specialized_type, env): |
| 341 """ |
| 342 For each specialized type, try to coerce the object to a memoryview |
| 343 slice of that type. This means obtaining a buffer and parsing the |
| 344 format string. |
| 345 TODO: separate buffer acquisition from format parsing |
| 346 """ |
| 347 dtype = specialized_type.dtype |
| 348 if specialized_type.is_buffer: |
| 349 axes = [('direct', 'strided')] * specialized_type.ndim |
| 350 else: |
| 351 axes = specialized_type.axes |
| 352 |
| 353 memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes) |
| 354 memslice_type.create_from_py_utility_code(env) |
| 355 pyx_code.context.update( |
| 356 coerce_from_py_func=memslice_type.from_py_function, |
| 357 dtype=dtype) |
| 358 decl_code.putln( |
| 359 "{{memviewslice_cname}} {{coerce_from_py_func}}(object)") |
| 360 |
| 361 pyx_code.context.update( |
| 362 specialized_type_name=specialized_type.specialization_string, |
| 363 sizeof_dtype=self._sizeof_dtype(dtype)) |
| 364 |
| 365 pyx_code.put_chunk( |
| 366 u""" |
| 367 # try {{dtype}} |
| 368 if itemsize == -1 or itemsize == {{sizeof_dtype}}: |
| 369 memslice = {{coerce_from_py_func}}(arg) |
| 370 if memslice.memview: |
| 371 __PYX_XDEC_MEMVIEW(&memslice, 1) |
| 372 # print 'found a match for the buffer through format par
sing' |
| 373 %s |
| 374 break |
| 375 else: |
| 376 __pyx_PyErr_Clear() |
| 377 """ % self.match) |
| 378 |
| 379 def _buffer_checks(self, buffer_types, pyx_code, decl_code, env): |
| 380 """ |
| 381 Generate Cython code to match objects to buffer specializations. |
| 382 First try to get a numpy dtype object and match it against the individua
l |
| 383 specializations. If that fails, try naively to coerce the object |
| 384 to each specialization, which obtains the buffer each time and tries |
| 385 to match the format string. |
| 386 """ |
| 387 from Cython.Compiler import ExprNodes |
| 388 if buffer_types: |
| 389 if pyx_code.indenter(u"else:"): |
| 390 # The first thing to find a match in this loop breaks out of the
loop |
| 391 if pyx_code.indenter(u"while 1:"): |
| 392 pyx_code.put_chunk( |
| 393 u""" |
| 394 if numpy is not None: |
| 395 if isinstance(arg, numpy.ndarray): |
| 396 dtype = arg.dtype |
| 397 elif (__pyx_memoryview_check(arg) and |
| 398 isinstance(arg.base, numpy.ndarray)): |
| 399 dtype = arg.base.dtype |
| 400 else: |
| 401 dtype = None |
| 402 |
| 403 itemsize = -1 |
| 404 if dtype is not None: |
| 405 itemsize = dtype.itemsize |
| 406 kind = ord(dtype.kind) |
| 407 dtype_signed = kind == ord('i') |
| 408 """) |
| 409 pyx_code.indent(2) |
| 410 pyx_code.named_insertion_point("numpy_dtype_checks") |
| 411 self._buffer_check_numpy_dtype(pyx_code, buffer_types) |
| 412 pyx_code.dedent(2) |
| 413 |
| 414 for specialized_type in buffer_types: |
| 415 self._buffer_parse_format_string_check( |
| 416 pyx_code, decl_code, specialized_type, env) |
| 417 |
| 418 pyx_code.putln(self.no_match) |
| 419 pyx_code.putln("break") |
| 420 pyx_code.dedent() |
| 421 |
| 422 pyx_code.dedent() |
| 423 else: |
| 424 pyx_code.putln("else: %s" % self.no_match) |
| 425 |
| 426 def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types): |
| 427 """ |
| 428 If we have any buffer specializations, write out some variable |
| 429 declarations and imports. |
| 430 """ |
| 431 decl_code.put_chunk( |
| 432 u""" |
| 433 ctypedef struct {{memviewslice_cname}}: |
| 434 void *memview |
| 435 |
| 436 void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil) |
| 437 bint __pyx_memoryview_check(object) |
| 438 """) |
| 439 |
| 440 pyx_code.local_variable_declarations.put_chunk( |
| 441 u""" |
| 442 cdef {{memviewslice_cname}} memslice |
| 443 cdef Py_ssize_t itemsize |
| 444 cdef bint dtype_signed |
| 445 cdef char kind |
| 446 |
| 447 itemsize = -1 |
| 448 """) |
| 449 |
| 450 pyx_code.imports.put_chunk( |
| 451 u""" |
| 452 try: |
| 453 import numpy |
| 454 except ImportError: |
| 455 numpy = None |
| 456 """) |
| 457 |
| 458 seen_int_dtypes = set() |
| 459 for buffer_type in all_buffer_types: |
| 460 dtype = buffer_type.dtype |
| 461 if dtype.is_typedef: |
| 462 #decl_code.putln("ctypedef %s %s" % (dtype.resolve(), |
| 463 # self._dtype_name(dtype))) |
| 464 decl_code.putln('ctypedef %s %s "%s"' % (dtype.resolve(), |
| 465 self._dtype_name(dtype)
, |
| 466 dtype.declaration_code(
""))) |
| 467 |
| 468 if buffer_type.dtype.is_int: |
| 469 if str(dtype) not in seen_int_dtypes: |
| 470 seen_int_dtypes.add(str(dtype)) |
| 471 pyx_code.context.update(dtype_name=self._dtype_name(dtype), |
| 472 dtype_type=self._dtype_type(dtype)) |
| 473 pyx_code.local_variable_declarations.put_chunk( |
| 474 u""" |
| 475 cdef bint {{dtype_name}}_is_signed |
| 476 {{dtype_name}}_is_signed = <{{dtype_type}}> -1 < 0 |
| 477 """) |
| 478 |
| 479 def _split_fused_types(self, arg): |
| 480 """ |
| 481 Specialize fused types and split into normal types and buffer types. |
| 482 """ |
| 483 specialized_types = PyrexTypes.get_specialized_types(arg.type) |
| 484 # Prefer long over int, etc |
| 485 # specialized_types.sort() |
| 486 seen_py_type_names = set() |
| 487 normal_types, buffer_types = [], [] |
| 488 for specialized_type in specialized_types: |
| 489 py_type_name = specialized_type.py_type_name() |
| 490 if py_type_name: |
| 491 if py_type_name in seen_py_type_names: |
| 492 continue |
| 493 seen_py_type_names.add(py_type_name) |
| 494 normal_types.append(specialized_type) |
| 495 elif specialized_type.is_buffer or specialized_type.is_memoryviewsli
ce: |
| 496 buffer_types.append(specialized_type) |
| 497 |
| 498 return normal_types, buffer_types |
| 499 |
| 500 def _unpack_argument(self, pyx_code): |
| 501 pyx_code.put_chunk( |
| 502 u""" |
| 503 # PROCESSING ARGUMENT {{arg_tuple_idx}} |
| 504 if {{arg_tuple_idx}} < len(args): |
| 505 arg = args[{{arg_tuple_idx}}] |
| 506 elif '{{arg.name}}' in kwargs: |
| 507 arg = kwargs['{{arg.name}}'] |
| 508 else: |
| 509 {{if arg.default:}} |
| 510 arg = defaults[{{default_idx}}] |
| 511 {{else}} |
| 512 raise TypeError("Expected at least %d arguments" % len(args)
) |
| 513 {{endif}} |
| 514 """) |
| 515 |
| 516 def make_fused_cpdef(self, orig_py_func, env, is_def): |
| 517 """ |
| 518 This creates the function that is indexable from Python and does |
| 519 runtime dispatch based on the argument types. The function gets the |
| 520 arg tuple and kwargs dict (or None) and the defaults tuple |
| 521 as arguments from the Binding Fused Function's tp_call. |
| 522 """ |
| 523 from Cython.Compiler import TreeFragment, Code, MemoryView, UtilityCode |
| 524 |
| 525 # { (arg_pos, FusedType) : specialized_type } |
| 526 seen_fused_types = set() |
| 527 |
| 528 context = { |
| 529 'memviewslice_cname': MemoryView.memviewslice_cname, |
| 530 'func_args': self.node.args, |
| 531 'n_fused': len([arg for arg in self.node.args]), |
| 532 'name': orig_py_func.entry.name, |
| 533 } |
| 534 |
| 535 pyx_code = Code.PyxCodeWriter(context=context) |
| 536 decl_code = Code.PyxCodeWriter(context=context) |
| 537 decl_code.put_chunk( |
| 538 u""" |
| 539 cdef extern from *: |
| 540 void __pyx_PyErr_Clear "PyErr_Clear" () |
| 541 """) |
| 542 decl_code.indent() |
| 543 |
| 544 pyx_code.put_chunk( |
| 545 u""" |
| 546 def __pyx_fused_cpdef(signatures, args, kwargs, defaults): |
| 547 dest_sig = [{{for _ in range(n_fused)}}None,{{endfor}}] |
| 548 |
| 549 if kwargs is None: |
| 550 kwargs = {} |
| 551 |
| 552 cdef Py_ssize_t i |
| 553 |
| 554 # instance check body |
| 555 """) |
| 556 pyx_code.indent() # indent following code to function body |
| 557 pyx_code.named_insertion_point("imports") |
| 558 pyx_code.named_insertion_point("local_variable_declarations") |
| 559 |
| 560 fused_index = 0 |
| 561 default_idx = 0 |
| 562 all_buffer_types = set() |
| 563 for i, arg in enumerate(self.node.args): |
| 564 if arg.type.is_fused and arg.type not in seen_fused_types: |
| 565 seen_fused_types.add(arg.type) |
| 566 |
| 567 context.update( |
| 568 arg_tuple_idx=i, |
| 569 arg=arg, |
| 570 dest_sig_idx=fused_index, |
| 571 default_idx=default_idx, |
| 572 ) |
| 573 |
| 574 normal_types, buffer_types = self._split_fused_types(arg) |
| 575 self._unpack_argument(pyx_code) |
| 576 self._fused_instance_checks(normal_types, pyx_code, env) |
| 577 self._buffer_checks(buffer_types, pyx_code, decl_code, env) |
| 578 fused_index += 1 |
| 579 |
| 580 all_buffer_types.update(buffer_types) |
| 581 |
| 582 if arg.default: |
| 583 default_idx += 1 |
| 584 |
| 585 if all_buffer_types: |
| 586 self._buffer_declarations(pyx_code, decl_code, all_buffer_types) |
| 587 env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportE
xport.c")) |
| 588 |
| 589 pyx_code.put_chunk( |
| 590 u""" |
| 591 candidates = [] |
| 592 for sig in signatures: |
| 593 match_found = False |
| 594 for src_type, dst_type in zip(sig.strip('()').split('|'), de
st_sig): |
| 595 if dst_type is not None: |
| 596 if src_type == dst_type: |
| 597 match_found = True |
| 598 else: |
| 599 match_found = False |
| 600 break |
| 601 |
| 602 if match_found: |
| 603 candidates.append(sig) |
| 604 |
| 605 if not candidates: |
| 606 raise TypeError("No matching signature found") |
| 607 elif len(candidates) > 1: |
| 608 raise TypeError("Function call with ambiguous argument types
") |
| 609 else: |
| 610 return signatures[candidates[0]] |
| 611 """) |
| 612 |
| 613 fragment_code = pyx_code.getvalue() |
| 614 # print decl_code.getvalue() |
| 615 # print fragment_code |
| 616 fragment = TreeFragment.TreeFragment(fragment_code, level='module') |
| 617 ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root) |
| 618 UtilityCode.declare_declarations_in_scope(decl_code.getvalue(), |
| 619 env.global_scope()) |
| 620 ast.scope = env |
| 621 ast.analyse_declarations(env) |
| 622 py_func = ast.stats[-1] # the DefNode |
| 623 self.fragment_scope = ast.scope |
| 624 |
| 625 if isinstance(self.node, DefNode): |
| 626 py_func.specialized_cpdefs = self.nodes[:] |
| 627 else: |
| 628 py_func.specialized_cpdefs = [n.py_func for n in self.nodes] |
| 629 |
| 630 return py_func |
| 631 |
| 632 def update_fused_defnode_entry(self, env): |
| 633 copy_attributes = ( |
| 634 'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname', |
| 635 'pymethdef_cname', 'doc', 'doc_cname', 'is_member', |
| 636 'scope' |
| 637 ) |
| 638 |
| 639 entry = self.py_func.entry |
| 640 |
| 641 for attr in copy_attributes: |
| 642 setattr(entry, attr, |
| 643 getattr(self.orig_py_func.entry, attr)) |
| 644 |
| 645 self.py_func.name = self.orig_py_func.name |
| 646 self.py_func.doc = self.orig_py_func.doc |
| 647 |
| 648 env.entries.pop('__pyx_fused_cpdef', None) |
| 649 if isinstance(self.node, DefNode): |
| 650 env.entries[entry.name] = entry |
| 651 else: |
| 652 env.entries[entry.name].as_variable = entry |
| 653 |
| 654 env.pyfunc_entries.append(entry) |
| 655 |
| 656 self.py_func.entry.fused_cfunction = self |
| 657 for node in self.nodes: |
| 658 if isinstance(self.node, DefNode): |
| 659 node.fused_py_func = self.py_func |
| 660 else: |
| 661 node.py_func.fused_py_func = self.py_func |
| 662 node.entry.as_variable = entry |
| 663 |
| 664 self.synthesize_defnodes() |
| 665 self.stats.append(self.__signatures__) |
| 666 |
| 667 def analyse_expressions(self, env): |
| 668 """ |
| 669 Analyse the expressions. Take care to only evaluate default arguments |
| 670 once and clone the result for all specializations |
| 671 """ |
| 672 for fused_compound_type in self.fused_compound_types: |
| 673 for fused_type in fused_compound_type.get_fused_types(): |
| 674 for specialization_type in fused_type.types: |
| 675 if specialization_type.is_complex: |
| 676 specialization_type.create_declaration_utility_code(env) |
| 677 |
| 678 if self.py_func: |
| 679 self.__signatures__ = self.__signatures__.analyse_expressions(env) |
| 680 self.py_func = self.py_func.analyse_expressions(env) |
| 681 self.resulting_fused_function = self.resulting_fused_function.analys
e_expressions(env) |
| 682 self.fused_func_assignment = self.fused_func_assignment.analyse_expr
essions(env) |
| 683 |
| 684 self.defaults = defaults = [] |
| 685 |
| 686 for arg in self.node.args: |
| 687 if arg.default: |
| 688 arg.default = arg.default.analyse_expressions(env) |
| 689 defaults.append(ProxyNode(arg.default)) |
| 690 else: |
| 691 defaults.append(None) |
| 692 |
| 693 for i, stat in enumerate(self.stats): |
| 694 stat = self.stats[i] = stat.analyse_expressions(env) |
| 695 if isinstance(stat, FuncDefNode): |
| 696 for arg, default in zip(stat.args, defaults): |
| 697 if default is not None: |
| 698 arg.default = CloneNode(default).coerce_to(arg.type, env
) |
| 699 |
| 700 if self.py_func: |
| 701 args = [CloneNode(default) for default in defaults if default] |
| 702 self.defaults_tuple = TupleNode(self.pos, args=args) |
| 703 self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_ch
ildren=True) |
| 704 self.defaults_tuple = ProxyNode(self.defaults_tuple) |
| 705 self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_objec
t) |
| 706 |
| 707 fused_func = self.resulting_fused_function.arg |
| 708 fused_func.defaults_tuple = CloneNode(self.defaults_tuple) |
| 709 fused_func.code_object = CloneNode(self.code_object) |
| 710 |
| 711 for i, pycfunc in enumerate(self.specialized_pycfuncs): |
| 712 pycfunc.code_object = CloneNode(self.code_object) |
| 713 pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(e
nv) |
| 714 pycfunc.defaults_tuple = CloneNode(self.defaults_tuple) |
| 715 return self |
| 716 |
| 717 def synthesize_defnodes(self): |
| 718 """ |
| 719 Create the __signatures__ dict of PyCFunctionNode specializations. |
| 720 """ |
| 721 if isinstance(self.nodes[0], CFuncDefNode): |
| 722 nodes = [node.py_func for node in self.nodes] |
| 723 else: |
| 724 nodes = self.nodes |
| 725 |
| 726 signatures = [ |
| 727 StringEncoding.EncodedString(node.specialized_signature_string) |
| 728 for node in nodes] |
| 729 keys = [ExprNodes.StringNode(node.pos, value=sig) |
| 730 for node, sig in zip(nodes, signatures)] |
| 731 values = [ExprNodes.PyCFunctionNode.from_defnode(node, True) |
| 732 for node in nodes] |
| 733 self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos, |
| 734 zip(keys, values)) |
| 735 |
| 736 self.specialized_pycfuncs = values |
| 737 for pycfuncnode in values: |
| 738 pycfuncnode.is_specialization = True |
| 739 |
| 740 def generate_function_definitions(self, env, code): |
| 741 if self.py_func: |
| 742 self.py_func.pymethdef_required = True |
| 743 self.fused_func_assignment.generate_function_definitions(env, code) |
| 744 |
| 745 for stat in self.stats: |
| 746 if isinstance(stat, FuncDefNode) and stat.entry.used: |
| 747 code.mark_pos(stat.pos) |
| 748 stat.generate_function_definitions(env, code) |
| 749 |
| 750 def generate_execution_code(self, code): |
| 751 # Note: all def function specialization are wrapped in PyCFunction |
| 752 # nodes in the self.__signatures__ dictnode. |
| 753 for default in self.defaults: |
| 754 if default is not None: |
| 755 default.generate_evaluation_code(code) |
| 756 |
| 757 if self.py_func: |
| 758 self.defaults_tuple.generate_evaluation_code(code) |
| 759 self.code_object.generate_evaluation_code(code) |
| 760 |
| 761 for stat in self.stats: |
| 762 code.mark_pos(stat.pos) |
| 763 if isinstance(stat, ExprNodes.ExprNode): |
| 764 stat.generate_evaluation_code(code) |
| 765 else: |
| 766 stat.generate_execution_code(code) |
| 767 |
| 768 if self.__signatures__: |
| 769 self.resulting_fused_function.generate_evaluation_code(code) |
| 770 |
| 771 code.putln( |
| 772 "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" % |
| 773 (self.resulting_fused_function.result(), |
| 774 self.__signatures__.result())) |
| 775 code.put_giveref(self.__signatures__.result()) |
| 776 |
| 777 self.fused_func_assignment.generate_execution_code(code) |
| 778 |
| 779 # Dispose of results |
| 780 self.resulting_fused_function.generate_disposal_code(code) |
| 781 self.defaults_tuple.generate_disposal_code(code) |
| 782 self.code_object.generate_disposal_code(code) |
| 783 |
| 784 for default in self.defaults: |
| 785 if default is not None: |
| 786 default.generate_disposal_code(code) |
| 787 |
| 788 def annotate(self, code): |
| 789 for stat in self.stats: |
| 790 stat.annotate(code) |
OLD | NEW |