Index: third_party/cython/src/Cython/Build/Inline.py |
diff --git a/third_party/cython/src/Cython/Build/Inline.py b/third_party/cython/src/Cython/Build/Inline.py |
new file mode 100644 |
index 0000000000000000000000000000000000000000..fcbb6c1282a3e9768ac21e25c692d51213498300 |
--- /dev/null |
+++ b/third_party/cython/src/Cython/Build/Inline.py |
@@ -0,0 +1,304 @@ |
+import sys, os, re, inspect |
+import imp |
+ |
+try: |
+ import hashlib |
+except ImportError: |
+ import md5 as hashlib |
+ |
+from distutils.core import Distribution, Extension |
+from distutils.command.build_ext import build_ext |
+ |
+import Cython |
+from Cython.Compiler.Main import Context, CompilationOptions, default_options |
+ |
+from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform |
+from Cython.Compiler.TreeFragment import parse_from_strings |
+from Cython.Build.Dependencies import strip_string_literals, cythonize, cached_function |
+from Cython.Compiler import Pipeline |
+from Cython.Utils import get_cython_cache_dir |
+import cython as cython_module |
+ |
+# A utility function to convert user-supplied ASCII strings to unicode. |
+if sys.version_info[0] < 3: |
+ def to_unicode(s): |
+ if not isinstance(s, unicode): |
+ return s.decode('ascii') |
+ else: |
+ return s |
+else: |
+ to_unicode = lambda x: x |
+ |
+ |
+class AllSymbols(CythonTransform, SkipDeclarations): |
+ def __init__(self): |
+ CythonTransform.__init__(self, None) |
+ self.names = set() |
+ def visit_NameNode(self, node): |
+ self.names.add(node.name) |
+ |
+@cached_function |
+def unbound_symbols(code, context=None): |
+ code = to_unicode(code) |
+ if context is None: |
+ context = Context([], default_options) |
+ from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform |
+ tree = parse_from_strings('(tree fragment)', code) |
+ for phase in Pipeline.create_pipeline(context, 'pyx'): |
+ if phase is None: |
+ continue |
+ tree = phase(tree) |
+ if isinstance(phase, AnalyseDeclarationsTransform): |
+ break |
+ symbol_collector = AllSymbols() |
+ symbol_collector(tree) |
+ unbound = [] |
+ try: |
+ import builtins |
+ except ImportError: |
+ import __builtin__ as builtins |
+ for name in symbol_collector.names: |
+ if not tree.scope.lookup(name) and not hasattr(builtins, name): |
+ unbound.append(name) |
+ return unbound |
+ |
+def unsafe_type(arg, context=None): |
+ py_type = type(arg) |
+ if py_type is int: |
+ return 'long' |
+ else: |
+ return safe_type(arg, context) |
+ |
+def safe_type(arg, context=None): |
+ py_type = type(arg) |
+ if py_type in [list, tuple, dict, str]: |
+ return py_type.__name__ |
+ elif py_type is complex: |
+ return 'double complex' |
+ elif py_type is float: |
+ return 'double' |
+ elif py_type is bool: |
+ return 'bint' |
+ elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray): |
+ return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim) |
+ else: |
+ for base_type in py_type.mro(): |
+ if base_type.__module__ in ('__builtin__', 'builtins'): |
+ return 'object' |
+ module = context.find_module(base_type.__module__, need_pxd=False) |
+ if module: |
+ entry = module.lookup(base_type.__name__) |
+ if entry.is_type: |
+ return '%s.%s' % (base_type.__module__, base_type.__name__) |
+ return 'object' |
+ |
+def _get_build_extension(): |
+ dist = Distribution() |
+ # Ensure the build respects distutils configuration by parsing |
+ # the configuration files |
+ config_files = dist.find_config_files() |
+ dist.parse_config_files(config_files) |
+ build_extension = build_ext(dist) |
+ build_extension.finalize_options() |
+ return build_extension |
+ |
+@cached_function |
+def _create_context(cython_include_dirs): |
+ return Context(list(cython_include_dirs), default_options) |
+ |
+def cython_inline(code, |
+ get_type=unsafe_type, |
+ lib_dir=os.path.join(get_cython_cache_dir(), 'inline'), |
+ cython_include_dirs=['.'], |
+ force=False, |
+ quiet=False, |
+ locals=None, |
+ globals=None, |
+ **kwds): |
+ if get_type is None: |
+ get_type = lambda x: 'object' |
+ code = to_unicode(code) |
+ orig_code = code |
+ code, literals = strip_string_literals(code) |
+ code = strip_common_indent(code) |
+ ctx = _create_context(tuple(cython_include_dirs)) |
+ if locals is None: |
+ locals = inspect.currentframe().f_back.f_back.f_locals |
+ if globals is None: |
+ globals = inspect.currentframe().f_back.f_back.f_globals |
+ try: |
+ for symbol in unbound_symbols(code): |
+ if symbol in kwds: |
+ continue |
+ elif symbol in locals: |
+ kwds[symbol] = locals[symbol] |
+ elif symbol in globals: |
+ kwds[symbol] = globals[symbol] |
+ else: |
+ print("Couldn't find ", symbol) |
+ except AssertionError: |
+ if not quiet: |
+ # Parsing from strings not fully supported (e.g. cimports). |
+ print("Could not parse code as a string (to extract unbound symbols).") |
+ cimports = [] |
+ for name, arg in kwds.items(): |
+ if arg is cython_module: |
+ cimports.append('\ncimport cython as %s' % name) |
+ del kwds[name] |
+ arg_names = kwds.keys() |
+ arg_names.sort() |
+ arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) |
+ key = orig_code, arg_sigs, sys.version_info, sys.executable, Cython.__version__ |
+ module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexdigest() |
+ |
+ if module_name in sys.modules: |
+ module = sys.modules[module_name] |
+ |
+ else: |
+ build_extension = None |
+ if cython_inline.so_ext is None: |
+ # Figure out and cache current extension suffix |
+ build_extension = _get_build_extension() |
+ cython_inline.so_ext = build_extension.get_ext_filename('') |
+ |
+ module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext) |
+ |
+ if not os.path.exists(lib_dir): |
+ os.makedirs(lib_dir) |
+ if force or not os.path.isfile(module_path): |
+ cflags = [] |
+ c_include_dirs = [] |
+ qualified = re.compile(r'([.\w]+)[.]') |
+ for type, _ in arg_sigs: |
+ m = qualified.match(type) |
+ if m: |
+ cimports.append('\ncimport %s' % m.groups()[0]) |
+ # one special case |
+ if m.groups()[0] == 'numpy': |
+ import numpy |
+ c_include_dirs.append(numpy.get_include()) |
+ # cflags.append('-Wno-unused') |
+ module_body, func_body = extract_func_code(code) |
+ params = ', '.join(['%s %s' % a for a in arg_sigs]) |
+ module_code = """ |
+%(module_body)s |
+%(cimports)s |
+def __invoke(%(params)s): |
+%(func_body)s |
+ """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body } |
+ for key, value in literals.items(): |
+ module_code = module_code.replace(key, value) |
+ pyx_file = os.path.join(lib_dir, module_name + '.pyx') |
+ fh = open(pyx_file, 'w') |
+ try: |
+ fh.write(module_code) |
+ finally: |
+ fh.close() |
+ extension = Extension( |
+ name = module_name, |
+ sources = [pyx_file], |
+ include_dirs = c_include_dirs, |
+ extra_compile_args = cflags) |
+ if build_extension is None: |
+ build_extension = _get_build_extension() |
+ build_extension.extensions = cythonize([extension], include_path=cython_include_dirs, quiet=quiet) |
+ build_extension.build_temp = os.path.dirname(pyx_file) |
+ build_extension.build_lib = lib_dir |
+ build_extension.run() |
+ |
+ module = imp.load_dynamic(module_name, module_path) |
+ |
+ arg_list = [kwds[arg] for arg in arg_names] |
+ return module.__invoke(*arg_list) |
+ |
+# Cached suffix used by cython_inline above. None should get |
+# overridden with actual value upon the first cython_inline invocation |
+cython_inline.so_ext = None |
+ |
+non_space = re.compile('[^ ]') |
+def strip_common_indent(code): |
+ min_indent = None |
+ lines = code.split('\n') |
+ for line in lines: |
+ match = non_space.search(line) |
+ if not match: |
+ continue # blank |
+ indent = match.start() |
+ if line[indent] == '#': |
+ continue # comment |
+ elif min_indent is None or min_indent > indent: |
+ min_indent = indent |
+ for ix, line in enumerate(lines): |
+ match = non_space.search(line) |
+ if not match or line[indent] == '#': |
+ continue |
+ else: |
+ lines[ix] = line[min_indent:] |
+ return '\n'.join(lines) |
+ |
+module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))') |
+def extract_func_code(code): |
+ module = [] |
+ function = [] |
+ current = function |
+ code = code.replace('\t', ' ') |
+ lines = code.split('\n') |
+ for line in lines: |
+ if not line.startswith(' '): |
+ if module_statement.match(line): |
+ current = module |
+ else: |
+ current = function |
+ current.append(line) |
+ return '\n'.join(module), ' ' + '\n '.join(function) |
+ |
+ |
+ |
+try: |
+ from inspect import getcallargs |
+except ImportError: |
+ def getcallargs(func, *arg_values, **kwd_values): |
+ all = {} |
+ args, varargs, kwds, defaults = inspect.getargspec(func) |
+ if varargs is not None: |
+ all[varargs] = arg_values[len(args):] |
+ for name, value in zip(args, arg_values): |
+ all[name] = value |
+ for name, value in kwd_values.items(): |
+ if name in args: |
+ if name in all: |
+ raise TypeError("Duplicate argument %s" % name) |
+ all[name] = kwd_values.pop(name) |
+ if kwds is not None: |
+ all[kwds] = kwd_values |
+ elif kwd_values: |
+ raise TypeError("Unexpected keyword arguments: %s" % kwd_values.keys()) |
+ if defaults is None: |
+ defaults = () |
+ first_default = len(args) - len(defaults) |
+ for ix, name in enumerate(args): |
+ if name not in all: |
+ if ix >= first_default: |
+ all[name] = defaults[ix - first_default] |
+ else: |
+ raise TypeError("Missing argument: %s" % name) |
+ return all |
+ |
+def get_body(source): |
+ ix = source.index(':') |
+ if source[:5] == 'lambda': |
+ return "return %s" % source[ix+1:] |
+ else: |
+ return source[ix+1:] |
+ |
+# Lots to be done here... It would be especially cool if compiled functions |
+# could invoke each other quickly. |
+class RuntimeCompiledFunction(object): |
+ |
+ def __init__(self, f): |
+ self._f = f |
+ self._body = get_body(inspect.getsource(f)) |
+ |
+ def __call__(self, *args, **kwds): |
+ all = getcallargs(self._f, *args, **kwds) |
+ return cython_inline(self._body, locals=self._f.func_globals, globals=self._f.func_globals, **all) |